From 4ee0e8c3e9d1425cf6f5a835492f54ffa1101c7c Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 23 Sep 2025 13:20:25 +1000 Subject: [PATCH 01/16] Add a Platform model where we can define the admin roles for each platform --- db/models.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/db/models.py b/db/models.py index 781e6202..d0de7c42 100644 --- a/db/models.py +++ b/db/models.py @@ -115,12 +115,28 @@ def add_group_membership( return membership +class PlatformRoleLink(BaseModel, table=True): + platform_id: PlatformEnum = Field(primary_key=True, foreign_key="platform.id", sa_type=DbEnum(PlatformEnum, name="PlatformEnum")) + role_id: str = Field(primary_key=True, foreign_key="auth0role.id") + + +class Platform(BaseModel, table=True): + id: PlatformEnum = Field(primary_key=True, unique=True, sa_type=DbEnum(PlatformEnum, name="PlatformEnum")) + # Human-readable name for the platform + name: str = Field(unique=True) + admin_roles: list["Auth0Role"] = Relationship( + back_populates="admin_platforms", link_model=PlatformRoleLink, + ) + members: list["PlatformMembership"] = Relationship(back_populates="platform") + + class PlatformMembership(BaseModel, table=True): __table_args__ = ( UniqueConstraint("platform_id", "user_id", name="platform_user_id_platform_id"), ) id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - platform_id: PlatformEnum = Field(sa_type=DbEnum(PlatformEnum, name="PlatformEnum")) + platform_id: PlatformEnum = Field(foreign_key="platform.id", sa_type=DbEnum(PlatformEnum, name="PlatformEnum")) + platform: Platform = Relationship(back_populates="members") user_id: str = Field(foreign_key="biocommons_user.id") user: "BiocommonsUser" = Relationship( back_populates="platform_memberships", @@ -351,6 +367,9 @@ class Auth0Role(BaseModel, table=True): admin_groups: list["BiocommonsGroup"] = Relationship( back_populates="admin_roles", link_model=GroupRoleLink ) + admin_platforms: list["Platform"] = Relationship( + back_populates="admin_roles", link_model=PlatformRoleLink + ) @classmethod def get_or_create_by_id( @@ -423,6 +442,7 @@ def user_is_admin(self, user: SessionUser) -> bool: # Update all model references BiocommonsUser.model_rebuild() +Platform.model_rebuild() PlatformMembership.model_rebuild() PlatformMembershipHistory.model_rebuild() GroupMembership.model_rebuild() From 181bd36db47eb7ec7ae64c48d4c6eecc6a5fca59 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 23 Sep 2025 13:21:10 +1000 Subject: [PATCH 02/16] Don't automatically set relationships in datagen models - doesn't work well with unique constraints on platform --- tests/db/datagen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/db/datagen.py b/tests/db/datagen.py index 37ced59a..200c4986 100644 --- a/tests/db/datagen.py +++ b/tests/db/datagen.py @@ -19,16 +19,16 @@ def id(cls) -> str: class Auth0RoleFactory(SQLAlchemyFactory[Auth0Role]): - __set_relationships__ = True + __set_relationships__ = False class BiocommonsGroupFactory(SQLAlchemyFactory[BiocommonsGroup]): - __set_relationships__ = True + __set_relationships__ = False class GroupMembershipFactory(SQLAlchemyFactory[GroupMembership]): - __set_relationships__ = True + __set_relationships__ = False class PlatformMembershipFactory(SQLAlchemyFactory[PlatformMembership]): - __set_relationships__ = True + __set_relationships__ = False From 4b551f5740a0ad253685d1b124dcf3cb11a5a2d6 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 23 Sep 2025 13:27:00 +1000 Subject: [PATCH 03/16] Make sure we set group/user ID when saving history --- db/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/db/models.py b/db/models.py index d0de7c42..8ffcb915 100644 --- a/db/models.py +++ b/db/models.py @@ -278,8 +278,8 @@ def save_history( session.flush() history = GroupMembershipHistory( - group=self.group, - user=self.user, + group_id=self.group_id, + user_id=self.user_id, approval_status=self.approval_status, updated_at=self.updated_at, updated_by=self.updated_by, From b71d4df62bdee5d6f4c55b46d00c6925811f214b Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 23 Sep 2025 13:27:15 +1000 Subject: [PATCH 04/16] Fix tests of GroupMembership --- tests/db/test_models.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/db/test_models.py b/tests/db/test_models.py index 20d02080..5485571d 100644 --- a/tests/db/test_models.py +++ b/tests/db/test_models.py @@ -324,8 +324,9 @@ def test_group_membership_grant_auth0_role_not_approved(status, test_auth0_clien membership_request.grant_auth0_role(test_auth0_client) -def test_group_membership_save_with_history(test_db_session): - membership = GroupMembershipFactory.build() +def test_group_membership_save_with_history(test_db_session, persistent_factories): + group = BiocommonsGroupFactory.create_sync() + membership = GroupMembershipFactory.build(group_id=group.group_id) membership.save(test_db_session, commit=True) test_db_session.refresh(membership) assert membership.id is not None @@ -339,7 +340,8 @@ def test_group_membership_save_with_history(test_db_session): def test_group_membership_save_and_commit_history(test_db_session, persistent_factories): - membership = GroupMembershipFactory.create_sync() + group = BiocommonsGroupFactory.create_sync() + membership = GroupMembershipFactory.build(group_id=group.group_id) membership.save_history(test_db_session, commit=True) test_db_session.refresh(membership) assert membership.id is not None From 0ffb735b4ffb5cad6829e879462f0ee31cbd40ac Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 23 Sep 2025 13:53:48 +1000 Subject: [PATCH 05/16] Add factory for creating platforms --- tests/conftest.py | 2 ++ tests/db/datagen.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index d7e9fe17..61d8dcd4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ BiocommonsGroupFactory, BiocommonsUserFactory, GroupMembershipFactory, + PlatformFactory, PlatformMembershipFactory, ) @@ -252,6 +253,7 @@ def persistent_factories(test_db_session): BiocommonsGroupFactory, BiocommonsUserFactory, GroupMembershipFactory, + PlatformFactory, PlatformMembershipFactory, ] for factory in factories: diff --git a/tests/db/datagen.py b/tests/db/datagen.py index 200c4986..ae7b7fee 100644 --- a/tests/db/datagen.py +++ b/tests/db/datagen.py @@ -5,6 +5,7 @@ BiocommonsGroup, BiocommonsUser, GroupMembership, + Platform, PlatformMembership, ) from tests.datagen import random_auth0_id @@ -32,3 +33,7 @@ class GroupMembershipFactory(SQLAlchemyFactory[GroupMembership]): class PlatformMembershipFactory(SQLAlchemyFactory[PlatformMembership]): __set_relationships__ = False + + +class PlatformFactory(SQLAlchemyFactory[Platform]): + __set_relationships__ = False From 062f353a023640a36ca247f91d63193895ff69a7 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 23 Sep 2025 14:04:06 +1000 Subject: [PATCH 06/16] Update test of platform membership creation to test platform is associated --- tests/db/test_models.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/db/test_models.py b/tests/db/test_models.py index 5485571d..312c62ca 100644 --- a/tests/db/test_models.py +++ b/tests/db/test_models.py @@ -28,6 +28,7 @@ BiocommonsGroupFactory, BiocommonsUserFactory, GroupMembershipFactory, + PlatformFactory, ) FROZEN_TIME = datetime(2025, 1, 1, 12, 0, 0) @@ -96,13 +97,15 @@ def test_get_or_create_biocommons_user_from_auth0(test_db_session, mock_auth0_cl assert user.email == user_data.email assert user.username == user_data.username + def test_create_platform_membership(test_db_session, persistent_factories, frozen_time): """ Test creating a platform membership model """ user = BiocommonsUserFactory.create_sync(platform_memberships=[]) + platform = PlatformFactory.create_sync(id=PlatformEnum.GALAXY) membership = PlatformMembership( - platform_id=PlatformEnum.GALAXY, + platform_id=platform.id, user_id=user.id, approval_status=ApprovalStatusEnum.APPROVED, updated_by_id=None @@ -111,8 +114,10 @@ def test_create_platform_membership(test_db_session, persistent_factories, froze test_db_session.commit() test_db_session.refresh(membership) assert membership.user_id == user.id + assert membership.user == user assert membership.approval_status == ApprovalStatusEnum.APPROVED assert membership.platform_id == "galaxy" + assert membership.platform.id == "galaxy" assert membership.updated_at == FROZEN_TIME From b610934ba776c25e8186adc907bed128ee1e7812 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 23 Sep 2025 14:11:13 +1000 Subject: [PATCH 07/16] Test platform creation --- tests/db/test_models.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/db/test_models.py b/tests/db/test_models.py index 312c62ca..2e91652a 100644 --- a/tests/db/test_models.py +++ b/tests/db/test_models.py @@ -17,6 +17,7 @@ BiocommonsUser, GroupMembership, GroupMembershipHistory, + Platform, PlatformEnum, PlatformMembership, PlatformMembershipHistory, @@ -98,6 +99,18 @@ def test_get_or_create_biocommons_user_from_auth0(test_db_session, mock_auth0_cl assert user.username == user_data.username +@pytest.mark.parametrize("platform_id", list(PlatformEnum)) +def test_create_platform(platform_id, test_db_session, persistent_factories): + admin_role = Auth0RoleFactory.create_sync() + platform = Platform( + id=platform_id, + name=f"Platform {platform_id}", + admin_roles=[admin_role] + ) + test_db_session.commit() + assert platform.id == platform_id + + def test_create_platform_membership(test_db_session, persistent_factories, frozen_time): """ Test creating a platform membership model From f9264f38616a6daefac17f1dae953291739bd2df Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 23 Sep 2025 14:31:09 +1000 Subject: [PATCH 08/16] Test platform IDs are unique --- tests/db/test_models.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/db/test_models.py b/tests/db/test_models.py index 2e91652a..f9c7b847 100644 --- a/tests/db/test_models.py +++ b/tests/db/test_models.py @@ -111,6 +111,16 @@ def test_create_platform(platform_id, test_db_session, persistent_factories): assert platform.id == platform_id +def test_create_platform_unique_id(test_db_session): + platform = Platform(id=PlatformEnum.GALAXY, name="Galaxy", admin_roles=[]) + test_db_session.add(platform) + test_db_session.commit() + with pytest.raises(IntegrityError): + platform = Platform(id=PlatformEnum.GALAXY, name="Galaxy Duplicate", admin_roles=[]) + test_db_session.add(platform) + test_db_session.commit() + + def test_create_platform_membership(test_db_session, persistent_factories, frozen_time): """ Test creating a platform membership model From 7873f4054fecac15ddaf3637757d3f513e266427 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 24 Sep 2025 09:53:08 +1000 Subject: [PATCH 09/16] Test getting admin platforms --- routers/user.py | 20 +++++++++++++++++++- tests/test_user.py | 23 +++++++++++++++++++++-- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/routers/user.py b/routers/user.py index e4bad50e..94b7acfc 100644 --- a/routers/user.py +++ b/routers/user.py @@ -10,7 +10,7 @@ from auth.management import get_management_token from auth.validator import get_current_user from config import Settings, get_settings -from db.models import GroupMembership, PlatformMembership +from db.models import Auth0Role, GroupMembership, Platform, PlatformMembership from db.setup import get_db_session from db.types import ApprovalStatusEnum from schemas.biocommons import Auth0UserData @@ -155,6 +155,24 @@ async def get_pending_platforms( return db_session.exec(query).all() +@router.get( + "/platforms/admin-roles", + description="Get platforms for which the current user has admin privileges.", +) +async def get_admin_platforms( + user: Annotated[SessionUser, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_db_session)], +): + """Get platforms for which the current user has admin privileges.""" + user_roles = user.access_token.biocommons_roles + query = ( + select(Platform) + .join(Platform.admin_roles) + .where(Auth0Role.name.in_(user_roles)) + ) + return db_session.exec(query).all() + + @router.get("/groups", response_model=list[GroupMembershipData],) async def get_groups( diff --git a/tests/test_user.py b/tests/test_user.py index 23211deb..757da95f 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -8,9 +8,11 @@ SessionUserFactory, ) from tests.db.datagen import ( + Auth0RoleFactory, BiocommonsGroupFactory, BiocommonsUserFactory, GroupMembershipFactory, + PlatformFactory, PlatformMembershipFactory, ) @@ -111,11 +113,11 @@ def test_check_is_admin_without_authentication(test_client): assert response.status_code == 401 -def _act_as_user(mocker, db_user): +def _act_as_user(mocker, db_user, roles: list[str] = None): """ Set up mocks so that the test client authenticates as the given user """ - access_token = AccessTokenPayloadFactory.build(sub=db_user.id) + access_token = AccessTokenPayloadFactory.build(sub=db_user.id, biocommons_roles=roles or []) auth0_user = SessionUserFactory.build(access_token=access_token) mocker.patch("auth.validator.verify_jwt", return_value=access_token) mocker.patch("auth.validator.get_current_user", return_value=auth0_user) @@ -239,3 +241,20 @@ def test_get_all_pending(test_client, test_db_session, mocker, persistent_factor assert len(data["platforms"]) == 1 platform_ids = [platform["platform_id"] for platform in data["platforms"]] assert pending_platform.platform_id in platform_ids + + +def test_get_admin_platforms(test_client, test_db_session, mocker, persistent_factories): + """Test that endpoint returns list of platforms the user is an admin for""" + admin_role = Auth0RoleFactory.create_sync(name="Admin") + other_platform_role = Auth0RoleFactory.create_sync(name="Other Platform Role") + user = BiocommonsUserFactory.create_sync() + valid_platform = PlatformFactory.create_sync(id=PlatformEnum.GALAXY, admin_roles=[admin_role]) + invalid_platform = PlatformFactory.create_sync(id=PlatformEnum.BPA_DATA_PORTAL, admin_roles=[other_platform_role]) + test_db_session.flush() + _act_as_user(mocker, user, roles=[admin_role.name]) + response = test_client.get("/me/platforms/admin-roles", headers={"Authorization": "Bearer valid_token"}) + assert response.status_code == 200 + data = response.json() + assert data[0] == valid_platform.model_dump(mode="json") + returned_ids = [p["id"] for p in data] + assert invalid_platform.id not in returned_ids From 58fda694ba4d0978389446b98594d5611ad341ab Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 24 Sep 2025 15:34:42 +1000 Subject: [PATCH 10/16] Add migrations for updated models --- .../08a3d0593418_platform_constraints.py | 31 ++++++++++ .../versions/575a146957f2_platform_model.py | 56 +++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 migrations/versions/08a3d0593418_platform_constraints.py create mode 100644 migrations/versions/575a146957f2_platform_model.py diff --git a/migrations/versions/08a3d0593418_platform_constraints.py b/migrations/versions/08a3d0593418_platform_constraints.py new file mode 100644 index 00000000..624a67b2 --- /dev/null +++ b/migrations/versions/08a3d0593418_platform_constraints.py @@ -0,0 +1,31 @@ +"""platform_constraints + +Revision ID: 08a3d0593418 +Revises: 575a146957f2 +Create Date: 2025-09-24 10:38:24.506817 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = '08a3d0593418' +down_revision: Union[str, None] = '575a146957f2' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint(op.f('uq_platform_id'), 'platform', ['id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(op.f('uq_platform_id'), 'platform', type_='unique') + # ### end Alembic commands ### diff --git a/migrations/versions/575a146957f2_platform_model.py b/migrations/versions/575a146957f2_platform_model.py new file mode 100644 index 00000000..526cf9de --- /dev/null +++ b/migrations/versions/575a146957f2_platform_model.py @@ -0,0 +1,56 @@ +"""platform_model + +Revision ID: 575a146957f2 +Revises: 1546c07b9d78 +Create Date: 2025-09-24 10:07:01.958231 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = '575a146957f2' +down_revision: Union[str, None] = '1546c07b9d78' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # NOTE: alembic doesn't automatically add new enum values to existing types + op.execute('ALTER TYPE "PlatformEnum" ADD VALUE \'SBP\'') + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('platform', + sa.Column('id', sa.Enum('GALAXY', 'BPA_DATA_PORTAL', 'SBP', name='PlatformEnum'), nullable=False), + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint('id', name=op.f('pk_platform')), + sa.UniqueConstraint('id', name=op.f('uq_platform_id')), + sa.UniqueConstraint('name', name=op.f('uq_platform_name')) + ) + + # Insert the platform records that might be referenced by existing platformmembership records + op.execute("INSERT INTO platform (id, name) VALUES ('GALAXY', 'Galaxy Australia')") + op.execute("INSERT INTO platform (id, name) VALUES ('BPA_DATA_PORTAL', 'Bioplatforms Australia Data Portal')") + op.execute("INSERT INTO platform (id, name) VALUES ('SBP', 'Structural Biology Platform')") + + op.create_table('platformrolelink', + sa.Column('platform_id', sa.Enum('GALAXY', 'BPA_DATA_PORTAL', 'SBP', name='PlatformEnum'), nullable=False), + sa.Column('role_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.ForeignKeyConstraint(['platform_id'], ['platform.id'], name=op.f('fk_platformrolelink_platform_id_platform')), + sa.ForeignKeyConstraint(['role_id'], ['auth0role.id'], name=op.f('fk_platformrolelink_role_id_auth0role')), + sa.PrimaryKeyConstraint('platform_id', 'role_id', name=op.f('pk_platformrolelink')) + ) + # Create foreign key constraint for platformmembership.platform_id + op.create_foreign_key(op.f('fk_platformmembership_platform_id_platform'), 'platformmembership', 'platform', ['platform_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(op.f('fk_platformmembership_platform_id_platform'), 'platformmembership', type_='foreignkey') + op.drop_table('platformrolelink') + op.drop_table('platform') + # ### end Alembic commands ### From f323d6347efa1ce135eeadb5234e9cab54acd581 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 24 Sep 2025 15:35:18 +1000 Subject: [PATCH 11/16] Add PlatformAdmin to backend DB admin --- db/admin.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/db/admin.py b/db/admin.py index 21468cf5..4817c7f2 100644 --- a/db/admin.py +++ b/db/admin.py @@ -16,6 +16,7 @@ BiocommonsUser, GroupMembership, GroupMembershipHistory, + Platform, PlatformMembership, PlatformMembershipHistory, ) @@ -130,6 +131,17 @@ class GroupMembershipHistoryAdmin(ModelView, model=GroupMembershipHistory): column_default_sort = ("updated_at", True) +class PlatformAdmin(ModelView, model=Platform): + can_edit = True + can_create = True + can_delete = True + form_include_pk = True + column_list = ["id", "name", "admin_roles"] + column_details_list = ["id", "name", "admin_roles", "members"] + column_default_sort = ("id", True) + + + class PlatformMembershipAdmin(ModelView, model=PlatformMembership): can_edit = False can_create = False @@ -171,6 +183,7 @@ class DatabaseAdmin: Auth0RoleAdmin, GroupMembershipAdmin, GroupMembershipHistoryAdmin, + PlatformAdmin, PlatformMembershipAdmin, PlatformMembershipHistoryAdmin, ) From 00e811566e70625e26800ac15b67fb26bcad6f06 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Fri, 26 Sep 2025 11:05:59 +1000 Subject: [PATCH 12/16] Update get_users endpoint to only return users for platforms the admin can access --- routers/admin.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/routers/admin.py b/routers/admin.py index b2722673..69001a0b 100644 --- a/routers/admin.py +++ b/routers/admin.py @@ -5,21 +5,24 @@ from fastapi import APIRouter, Depends, HTTPException, Path from fastapi.params import Query from pydantic import BaseModel, Field, ValidationError -from sqlalchemy import false, func, or_ +from sqlalchemy import alias, false, func, or_ from sqlmodel import Session, select -from auth.validator import user_is_admin +from auth.validator import get_current_user, user_is_admin from auth0.client import Auth0Client, get_auth0_client from db.models import ( + Auth0Role, BiocommonsGroup, BiocommonsUser, GroupMembership, + Platform, PlatformEnum, PlatformMembership, ) from db.setup import get_db_session from db.types import ApprovalStatusEnum, GroupEnum from schemas.biocommons import Auth0UserDataWithMemberships +from schemas.user import SessionUser logger = logging.getLogger('uvicorn.error') @@ -105,7 +108,8 @@ def get_filter_options(): @router.get("/users", response_model=list[BiocommonsUserResponse]) -def get_users(db_session: Annotated[Session, Depends(get_db_session)], +def get_users(admin_user: Annotated[SessionUser, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_db_session)], pagination: Annotated[PaginationParams, Depends(get_pagination_params)], filter_by: str = Query(None, description="Filter users by group ('tsi', 'bpa_galaxy') or platform ('galaxy', 'bpa_data_portal')"), search: str = Query(None, description="Search users by username or email")): @@ -118,7 +122,21 @@ def get_users(db_session: Annotated[Session, Depends(get_db_session)], - Platform names: 'galaxy', 'bpa_data_portal' search: Optional search parameter for username or email """ - base_query = select(BiocommonsUser) + admin_roles = admin_user.access_token.biocommons_roles + # Base query with platform access filtering built-in + allowed_platforms_subquery = ( + select(Platform.id) + .join(Platform.admin_roles) + .where(Auth0Role.name.in_(admin_roles)) + ).alias("allowed_platforms") + # Need an alias or SQLAlchemy complains about duplicate column names + pm = alias(PlatformMembership, name="pm") + base_query = ( + select(BiocommonsUser) + .join(pm, BiocommonsUser.id == pm.c.user_id) + .where(pm.c.platform_id.in_(allowed_platforms_subquery)) + .distinct() + ) if filter_by: if filter_by in GROUP_MAPPING: From 6730fd62e40cf3465472cd3ad67ea3466da29e51 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Fri, 26 Sep 2025 11:14:48 +1000 Subject: [PATCH 13/16] Update tests of get_users to include platform-specific users --- tests/test_admin.py | 272 +++++++++++++++++++++++++------------------- 1 file changed, 153 insertions(+), 119 deletions(-) diff --git a/tests/test_admin.py b/tests/test_admin.py index 84b087c8..00efd221 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -4,11 +4,13 @@ import pytest from fastapi import HTTPException from freezegun import freeze_time +from sqlmodel import Session from auth.management import get_management_token from auth.validator import get_current_user, user_is_admin from auth0.client import Auth0Client -from db.types import ApprovalStatusEnum, PlatformEnum +from db.models import BiocommonsGroup +from db.types import ApprovalStatusEnum, GroupEnum, PlatformEnum from main import app from routers.admin import PaginationParams from tests.datagen import ( @@ -18,15 +20,29 @@ SessionUserFactory, ) from tests.db.datagen import ( + Auth0RoleFactory, BiocommonsGroupFactory, BiocommonsUserFactory, GroupMembershipFactory, + PlatformFactory, PlatformMembershipFactory, ) FROZEN_TIME = datetime(2025, 1, 1, 12, 0, 0) +@pytest.fixture +def dummy_platform(persistent_factories): + """ + Set up a Galaxy platform with the admin role set to "Admin" + """ + admin_role = Auth0RoleFactory.create_sync(name="Admin") + return PlatformFactory.create_sync( + id=PlatformEnum.GALAXY, + admin_roles=[admin_role] + ) + + @pytest.fixture def frozen_time(): """ @@ -71,24 +87,54 @@ def test_user_is_admin_nonadmin_user(mock_settings): user_is_admin(current_user=user, settings=mock_settings) -def test_get_users(test_client, as_admin_user, mock_auth0_client, test_db_session): - # Create some test users in the database - db_users = BiocommonsUserFactory.batch(3) +def _create_user_with_platform_membership(db_session: Session, platform_id: PlatformEnum, **kwargs): + user = BiocommonsUserFactory.build(**kwargs) + membership = PlatformMembershipFactory.create_sync( + platform_id=platform_id, + user_id=user.id, + approval_status=ApprovalStatusEnum.APPROVED, + ) + user.platform_memberships.append(membership) + db_session.add(user) + db_session.commit() + return user + + +def _users_with_platform_membership(n: int, db_session: Session, platform_id: PlatformEnum): + db_users = BiocommonsUserFactory.batch(n) for user in db_users: - test_db_session.add(user) - test_db_session.commit() + membership = PlatformMembershipFactory.create_sync( + platform_id=platform_id, + user_id=user.id, + approval_status=ApprovalStatusEnum.APPROVED + ) + user.platform_memberships.append(membership) + db_session.add(user) + db_session.commit() + return db_users + + +def test_get_users(test_client, as_admin_user, dummy_platform, + mock_auth0_client, test_db_session, persistent_factories): + """ + Test getting a list of users. The list should only contain users with platform memberships + that the admin user has access to. + """ + valid_users = _users_with_platform_membership(n=3, db_session=test_db_session, platform_id=dummy_platform.id) + other_platform = PlatformFactory.create_sync(platform_id=PlatformEnum.BPA_DATA_PORTAL) + invalid_users = _users_with_platform_membership(n=2, db_session=test_db_session, platform_id=other_platform.id) resp = test_client.get("/admin/users") assert resp.status_code == 200 - assert len(resp.json()) == 3 + data = resp.json() + assert len(data) == 3 + user_ids = [u["id"] for u in data] + assert all(u.id in user_ids for u in valid_users) + assert all(u.id not in user_ids for u in invalid_users) -def test_get_users_pagination_params(test_client, as_admin_user, mock_auth0_client, test_db_session): - # Create some test users in the database - db_users = BiocommonsUserFactory.batch(3) - for user in db_users: - test_db_session.add(user) - test_db_session.commit() +def test_get_users_pagination_params(test_client, as_admin_user, dummy_platform, mock_auth0_client, test_db_session): + _users_with_platform_membership(n=3, db_session=test_db_session, platform_id=dummy_platform.id) resp = test_client.get("/admin/users?page=2&per_page=10") assert resp.status_code == 200 @@ -96,79 +142,62 @@ def test_get_users_pagination_params(test_client, as_admin_user, mock_auth0_clie assert len(resp.json()) == 0 -def test_get_users_invalid_params(test_client, as_admin_user, mock_auth0_client): - users = Auth0UserDataFactory.batch(3) - mock_auth0_client.get_users.return_value = users +def test_get_users_invalid_params(test_client, as_admin_user, dummy_platform, test_db_session, mock_auth0_client): + _users_with_platform_membership(n=3, db_session=test_db_session, platform_id=dummy_platform.id) resp = test_client.get("/admin/users?page=0&per_page=500") assert resp.status_code == 422 error_msg = resp.json()["detail"] assert "Invalid page params" in error_msg -def test_get_users_filter_by_platform(test_client, as_admin_user, test_db_session): - from db.models import ApprovalStatusEnum, PlatformEnum, PlatformMembership - from tests.db.datagen import BiocommonsUserFactory - - galaxy_users = BiocommonsUserFactory.batch(2) +def test_get_users_filter_by_platform(test_client, as_admin_user, dummy_platform, test_db_session, + persistent_factories): + galaxy_users = _users_with_platform_membership(n=2, db_session=test_db_session, platform_id=dummy_platform.id) other_users = BiocommonsUserFactory.batch(2) - - for user in galaxy_users + other_users: + for user in other_users: test_db_session.add(user) test_db_session.commit() - for user in galaxy_users: - membership = PlatformMembership( - user_id=user.id, - platform_id=PlatformEnum.GALAXY, - approval_status=ApprovalStatusEnum.APPROVED - ) - test_db_session.add(membership) - test_db_session.commit() - resp = test_client.get("/admin/users?filter_by=galaxy") assert resp.status_code == 200 - assert len(resp.json()) == 2 + galaxy_data = resp.json() + assert len(galaxy_data) == 2 + galaxy_ids = [u["id"] for u in galaxy_data] + assert all(u.id in galaxy_ids for u in galaxy_users) resp = test_client.get("/admin/users?filter_by=bpa_data_portal") assert resp.status_code == 200 assert len(resp.json()) == 0 -def test_get_users_filter_by_group(test_client, as_admin_user, test_db_session): - from db.models import ( - ApprovalStatusEnum, - BiocommonsGroup, - GroupMembership, - ) - from db.types import GroupEnum - from tests.db.datagen import BiocommonsUserFactory - +def test_get_users_filter_by_group(test_client, as_admin_user, dummy_platform, test_db_session): tsi_group = BiocommonsGroup( group_id=GroupEnum.TSI, name="Threatened Species Initiative Bundle" ) test_db_session.add(tsi_group) test_db_session.commit() - - tsi_users = BiocommonsUserFactory.batch(2) - other_users = BiocommonsUserFactory.batch(2) - - for user in tsi_users + other_users: - test_db_session.add(user) - test_db_session.commit() + # Create users who can be managed by the admin user + tsi_users = _users_with_platform_membership(n=3, db_session=test_db_session, platform_id=dummy_platform.id) + other_users = _users_with_platform_membership(n=2, db_session=test_db_session, platform_id=dummy_platform.id) for user in tsi_users: - membership = GroupMembership( + membership = GroupMembershipFactory.create_sync( user_id=user.id, group_id=GroupEnum.TSI, approval_status=ApprovalStatusEnum.APPROVED ) + user.group_memberships.append(membership) test_db_session.add(membership) test_db_session.commit() resp = test_client.get("/admin/users?filter_by=tsi") assert resp.status_code == 200 - assert len(resp.json()) == 2 + tsi_data = resp.json() + assert len(tsi_data) == 3 + tsi_ids = [u["id"] for u in tsi_data] + assert all(u.id in tsi_ids for u in tsi_users) + assert all(u.id not in tsi_ids for u in other_users) resp = test_client.get("/admin/users?filter_by=bpa_galaxy") assert resp.status_code == 404 @@ -181,34 +210,41 @@ def test_get_users_invalid_filter(test_client, as_admin_user, test_db_session): assert "Invalid filter_by value 'invalid_filter'" in resp.json()["detail"] -def test_get_users_search_by_email_exact(test_client, as_admin_user, test_db_session): - from tests.db.datagen import BiocommonsUserFactory - - user1 = BiocommonsUserFactory.build(email="john.doe@example.com", username="johndoe") - user2 = BiocommonsUserFactory.build(email="jane.smith@example.com", username="janesmith") - user3 = BiocommonsUserFactory.build(email="bob.wilson@example.com", username="bobwilson") - - for user in [user1, user2, user3]: - test_db_session.add(user) - test_db_session.commit() +def test_get_users_search_by_email_exact(test_client, as_admin_user, dummy_platform, test_db_session): + user1 = _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="john.doe@example.com", username="johndoe" + ) + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="jane.smith@example.com", username="janesmith" + ) + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="bob.wilson@example.com", username="bobwilson" + ) resp = test_client.get("/admin/users?search=john.doe@example.com") assert resp.status_code == 200 results = resp.json() assert len(results) == 1 assert results[0]["email"] == "john.doe@example.com" + assert results[0]["id"] == user1.id -def test_get_users_search_by_email_partial(test_client, as_admin_user, test_db_session): - from tests.db.datagen import BiocommonsUserFactory - - user1 = BiocommonsUserFactory.build(email="john.doe@example.com", username="johndoe") - user2 = BiocommonsUserFactory.build(email="jane.smith@example.com", username="janesmith") - user3 = BiocommonsUserFactory.build(email="bob.wilson@different.com", username="bobwilson") - - for user in [user1, user2, user3]: - test_db_session.add(user) - test_db_session.commit() +def test_get_users_search_by_email_partial(test_client, as_admin_user, dummy_platform, test_db_session): + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="john.doe@example.com", username="johndoe" + ) + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="jane.smith@example.com", username="janesmith" + ) + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="bob.wilson@different.com", username="bobwilson" + ) resp = test_client.get("/admin/users?search=example.com") assert resp.status_code == 200 @@ -220,16 +256,19 @@ def test_get_users_search_by_email_partial(test_client, as_admin_user, test_db_s assert "bob.wilson@different.com" not in emails -def test_get_users_search_by_username(test_client, as_admin_user, test_db_session): - from tests.db.datagen import BiocommonsUserFactory - - user1 = BiocommonsUserFactory.build(email="john@example.com", username="johndoe") - user2 = BiocommonsUserFactory.build(email="jane@example.com", username="janesmith") - user3 = BiocommonsUserFactory.build(email="bob@example.com", username="bobwilson") - - for user in [user1, user2, user3]: - test_db_session.add(user) - test_db_session.commit() +def test_get_users_search_by_username(test_client, as_admin_user, dummy_platform, test_db_session): + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="john.doe@example.com", username="johndoe" + ) + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="jane.smith@example.com", username="janesmith" + ) + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="bob.wilson@different.com", username="bobwilson" + ) resp = test_client.get("/admin/users?search=john") assert resp.status_code == 200 @@ -238,16 +277,19 @@ def test_get_users_search_by_username(test_client, as_admin_user, test_db_sessio assert results[0]["username"] == "johndoe" -def test_get_users_search_by_username_partial(test_client, as_admin_user, test_db_session): - from tests.db.datagen import BiocommonsUserFactory - - user1 = BiocommonsUserFactory.build(email="john@example.com", username="johnsmith") - user2 = BiocommonsUserFactory.build(email="jane@example.com", username="johndoe") - user3 = BiocommonsUserFactory.build(email="bob@example.com", username="bobwilson") - - for user in [user1, user2, user3]: - test_db_session.add(user) - test_db_session.commit() +def test_get_users_search_by_username_partial(test_client, as_admin_user, dummy_platform, test_db_session): + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="john.doe@example.com", username="johndoe" + ) + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="smith@example.com", username="johnsmith" + ) + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="bob.wilson@example.com", username="bobwilson" + ) resp = test_client.get("/admin/users?search=john") assert resp.status_code == 200 @@ -259,13 +301,11 @@ def test_get_users_search_by_username_partial(test_client, as_admin_user, test_d assert "bobwilson" not in usernames -def test_get_users_search_case_insensitive(test_client, as_admin_user, test_db_session): - from tests.db.datagen import BiocommonsUserFactory - - user1 = BiocommonsUserFactory.build(email="John.Doe@Example.Com", username="JohnDoe") - - test_db_session.add(user1) - test_db_session.commit() +def test_get_users_search_case_insensitive(test_client, as_admin_user, dummy_platform, test_db_session): + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="John.Doe@Example.Com", username="JohnDoe" + ) resp = test_client.get("/admin/users?search=JOHN") assert resp.status_code == 200 @@ -280,11 +320,11 @@ def test_get_users_search_case_insensitive(test_client, as_admin_user, test_db_s assert results[0]["email"] == "John.Doe@Example.Com" -def test_get_users_search_empty_string(test_client, as_admin_user, test_db_session): - from tests.db.datagen import BiocommonsUserFactory - +def test_get_users_search_empty_string(test_client, as_admin_user, dummy_platform, test_db_session): users = BiocommonsUserFactory.batch(3) for user in users: + membership = PlatformMembershipFactory.create_sync(user_id=user.id, platform_id=dummy_platform.id) + user.platform_memberships.append(membership) test_db_session.add(user) test_db_session.commit() @@ -299,24 +339,16 @@ def test_get_users_search_empty_string(test_client, as_admin_user, test_db_sessi assert len(results) == 3 -def test_get_users_search_with_filter(test_client, as_admin_user, test_db_session): - from db.models import ApprovalStatusEnum, PlatformEnum, PlatformMembership - from tests.db.datagen import BiocommonsUserFactory - - user1 = BiocommonsUserFactory.build(email="john@example.com", username="johndoe") - user2 = BiocommonsUserFactory.build(email="jane@example.com", username="janesmith") - - for user in [user1, user2]: - test_db_session.add(user) - test_db_session.commit() - - membership = PlatformMembership( - user_id=user1.id, - platform_id=PlatformEnum.GALAXY, - approval_status=ApprovalStatusEnum.APPROVED +def test_get_users_search_with_filter(test_client, as_admin_user, dummy_platform, test_db_session): + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=dummy_platform.id, + email="john.doe@example.com", username="johndoe" + ) + other_platform = PlatformFactory.create_sync(id=PlatformEnum.BPA_DATA_PORTAL) + _create_user_with_platform_membership( + db_session=test_db_session, platform_id=other_platform.id, + email="jane@example.com", username="janesmith" ) - test_db_session.add(membership) - test_db_session.commit() resp = test_client.get("/admin/users?filter_by=galaxy&search=john") assert resp.status_code == 200 @@ -390,7 +422,9 @@ def test_get_pending_users(test_client, test_db_session, as_admin_user, persiste def test_get_revoked_users(test_client, test_db_session, as_admin_user, persistent_factories): revoked_users = BiocommonsUserFactory.create_batch_sync(3) for u in revoked_users: - PlatformMembershipFactory.create_sync(user=u, platform_id=PlatformEnum.GALAXY, approval_status=ApprovalStatusEnum.REVOKED) + PlatformMembershipFactory.create_sync( + user=u, platform_id=PlatformEnum.GALAXY, approval_status=ApprovalStatusEnum.REVOKED + ) test_db_session.commit() resp = test_client.get("/admin/users/revoked") assert resp.status_code == 200 From 338a1d715871793349d920b4f2378a9d9101e694 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Fri, 26 Sep 2025 11:20:19 +1000 Subject: [PATCH 14/16] Style fix --- tests/test_admin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_admin.py b/tests/test_admin.py index 00efd221..d4141fa4 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -320,7 +320,7 @@ def test_get_users_search_case_insensitive(test_client, as_admin_user, dummy_pla assert results[0]["email"] == "John.Doe@Example.Com" -def test_get_users_search_empty_string(test_client, as_admin_user, dummy_platform, test_db_session): +def test_get_users_search_empty_string(test_client, as_admin_user, dummy_platform, test_db_session): users = BiocommonsUserFactory.batch(3) for user in users: membership = PlatformMembershipFactory.create_sync(user_id=user.id, platform_id=dummy_platform.id) From 7a995810ff3246f6ba3b1846f184d522214f43c8 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Fri, 26 Sep 2025 11:26:13 +1000 Subject: [PATCH 15/16] Fix creating platform in test --- tests/test_admin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_admin.py b/tests/test_admin.py index d4141fa4..e0928f5f 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -39,6 +39,7 @@ def dummy_platform(persistent_factories): admin_role = Auth0RoleFactory.create_sync(name="Admin") return PlatformFactory.create_sync( id=PlatformEnum.GALAXY, + name="Galaxy Australia", admin_roles=[admin_role] ) @@ -121,7 +122,7 @@ def test_get_users(test_client, as_admin_user, dummy_platform, that the admin user has access to. """ valid_users = _users_with_platform_membership(n=3, db_session=test_db_session, platform_id=dummy_platform.id) - other_platform = PlatformFactory.create_sync(platform_id=PlatformEnum.BPA_DATA_PORTAL) + other_platform = PlatformFactory.create_sync(id=PlatformEnum.BPA_DATA_PORTAL) invalid_users = _users_with_platform_membership(n=2, db_session=test_db_session, platform_id=other_platform.id) resp = test_client.get("/admin/users") From 9f2df65867c9d55e5801d8a36d8e2d2058a8ef99 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Fri, 26 Sep 2025 12:36:51 +1000 Subject: [PATCH 16/16] Comment on checking platform membership creation --- tests/db/test_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/db/test_models.py b/tests/db/test_models.py index f9c7b847..6393a3c2 100644 --- a/tests/db/test_models.py +++ b/tests/db/test_models.py @@ -140,6 +140,7 @@ def test_create_platform_membership(test_db_session, persistent_factories, froze assert membership.user == user assert membership.approval_status == ApprovalStatusEnum.APPROVED assert membership.platform_id == "galaxy" + # Check the related platform object is populated assert membership.platform.id == "galaxy" assert membership.updated_at == FROZEN_TIME