From 5cd32225772c32021e61242dbaf4d5d3a8c6d00d Mon Sep 17 00:00:00 2001 From: KinshukSS2 Date: Sat, 30 May 2026 11:52:21 +0530 Subject: [PATCH] fix(rbac): auto-create RLS policy on user creation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes Issue #28 — eliminates two-step user provisioning. After POST /Users, a newly created user had no RLS policy and could not access any data until an administrator separately called POST /Policies. This commit fixes that by automatically calling the appropriate policy function inside the same transaction as user creation. Changes: - api/app/v1/endpoints/create/user.py * Add module-level _POLICY_FN_MAP (viewer/editor/obs_manager/sensor). * Capture app_role before get_db_role_for_rbac() remaps it, so the correct policy function can be dispatched. * After GRANT, call sensorthings.{role}_policy([username], policyname) with policyname = '{username}_default'. Administrator is skipped — admins bypass RLS by privilege, not by policy. * Policy functions already exist in the DB (istsos_auth.sql); no migration required. - api/app/v1/endpoints/functions.py * Add docstrings to _validate_role_identifier() and set_role(). - api/app/v1/endpoints/create/data_array_observation.py * Import shared set_role() helper (was already using the correct upstream version; this import makes the dependency explicit). - api/tests/test_rls_policy_creation.py (new) * Tests: correct policy function per role, administrator exclusion, naming convention, users_ as text[]. - api/tests/test_rbac_set_role_safety.py (new) * Tests: identifier validation, injection rejection, shared helper usage in data_array_observation. 
--- .../create/data_array_observation.py | 1 + api/app/v1/endpoints/create/user.py | 23 ++ api/app/v1/endpoints/functions.py | 14 ++ api/tests/test_rbac_set_role_safety.py | 86 ++++++++ api/tests/test_rls_policy_creation.py | 206 ++++++++++++++++++ 5 files changed, 330 insertions(+) create mode 100644 api/tests/test_rbac_set_role_safety.py create mode 100644 api/tests/test_rls_policy_creation.py diff --git a/api/app/v1/endpoints/create/data_array_observation.py b/api/app/v1/endpoints/create/data_array_observation.py index aed59c9c..5e77e8cf 100644 --- a/api/app/v1/endpoints/create/data_array_observation.py +++ b/api/app/v1/endpoints/create/data_array_observation.py @@ -38,6 +38,7 @@ set_commit, update_datastream_last_foi_id, ) +from app.v1.endpoints.functions import set_role v1 = APIRouter() diff --git a/api/app/v1/endpoints/create/user.py b/api/app/v1/endpoints/create/user.py index 4b18e6ef..786f77f3 100644 --- a/api/app/v1/endpoints/create/user.py +++ b/api/app/v1/endpoints/create/user.py @@ -41,6 +41,14 @@ "role": "viewer", # viewer, editor, obs_manager, sensor, custom } +# Maps the application-layer role to its sensorthings RLS policy function. +# Administrator is intentionally absent — admins bypass RLS by privilege. +_POLICY_FN_MAP = { + "viewer": "sensorthings.viewer_policy", + "editor": "sensorthings.editor_policy", + "obs_manager": "sensorthings.obs_manager_policy", + "sensor": "sensorthings.sensor_policy", +} @v1.api_route( "/Users", @@ -147,6 +155,9 @@ async def create_user( if current_user is not None: await connection.execute("RESET ROLE;") + # Capture app_role before get_db_role_for_rbac() to use + # for RLS policy dispatch below (fixes Issue #28). + app_role = payload["role"] db_role = get_db_role_for_rbac(payload["role"]) await connection.execute( @@ -164,6 +175,18 @@ async def create_user( ) ) + # Auto-create the default RLS policy for the new user. + # Policy functions already exist in the DB (istsos_auth.sql). 
+ # Administrator role bypasses RLS by privilege, not policy. + policy_fn = _POLICY_FN_MAP.get(app_role) + if policy_fn: + policyname = f"{user['username']}_default" + await connection.execute( + f"SELECT {policy_fn}($1, $2);", + [user["username"]], + policyname, + ) + return Response(status_code=status.HTTP_201_CREATED) except UniqueViolationError: diff --git a/api/app/v1/endpoints/functions.py b/api/app/v1/endpoints/functions.py index db49467a..ca59ad50 100644 --- a/api/app/v1/endpoints/functions.py +++ b/api/app/v1/endpoints/functions.py @@ -22,12 +22,26 @@ def _validate_role_identifier(username: str) -> str: + """Validate that *username* is a safe PostgreSQL identifier. + + asyncpg does not support $1 placeholders for SET ROLE identifiers — + PostgreSQL's SET ROLE only accepts a literal role name. We therefore + validate the identifier before interpolating it into the query string. + + Raises: + ValueError: if the username does not match a plain PG identifier. + """ if not isinstance(username, str) or not _PG_IDENTIFIER_RE.match(username): raise ValueError("Invalid role identifier") return username async def set_role(connection, current_user): + """Switch the current session role to *current_user['username']*. + + The username is validated against ``_PG_IDENTIFIER_RE`` before use. + Uses ``pg_quote_ident`` to safely quote the identifier for the query. 
+ """ async with connection.transaction(): username = _validate_role_identifier(current_user["username"]) query = f"SET ROLE {pg_quote_ident(username)};" diff --git a/api/tests/test_rbac_set_role_safety.py b/api/tests/test_rbac_set_role_safety.py new file mode 100644 index 00000000..195a0462 --- /dev/null +++ b/api/tests/test_rbac_set_role_safety.py @@ -0,0 +1,86 @@ +import asyncio +import inspect +import os +import sys +from pathlib import Path + +import pytest + +API_DIR = str(Path(__file__).resolve().parents[1]) +if API_DIR not in sys.path: + sys.path.insert(0, API_DIR) + +os.environ.setdefault("ISTSOS_ADMIN", "admin") +os.environ.setdefault("ISTSOS_ADMIN_PASSWORD", "secret") +os.environ.setdefault("POSTGRES_HOST", "localhost") +os.environ.setdefault("POSTGRES_PORT", "5432") +os.environ.setdefault("POSTGRES_DB", "istsos") +os.environ.setdefault("POSTGRES_USER", "admin") +os.environ.setdefault("SECRET_KEY", "test_secret_key_1234567890") +os.environ.setdefault("ALGORITHM", "HS256") + +import app.v1.endpoints.create.data_array_observation as dao +from app.v1.endpoints.functions import set_role + + +class _Tx: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class _Conn: + def __init__(self): + self.executed = [] + + def transaction(self): + return _Tx() + + async def execute(self, query): + self.executed.append(query) + + +@pytest.mark.parametrize( + "username, expected_query", + [ + ("alice", 'SET ROLE "alice";'), + ("user_1", 'SET ROLE "user_1";'), + ], +) +def test_set_role_allows_safe_identifiers(username, expected_query): + conn = _Conn() + + async def _run(): + await set_role(conn, {"username": username}) + + asyncio.run(_run()) + assert conn.executed == [expected_query] + + +@pytest.mark.parametrize( + "username", + [ + 'attacker"; RESET ROLE; --', + "bad-user", + "1starts_with_digit", + "user space", + "", + ], +) +def test_set_role_rejects_unsafe_identifiers(username): + conn = _Conn() + + 
async def _run(): + await set_role(conn, {"username": username}) + + with pytest.raises(ValueError, match="Invalid role identifier"): + asyncio.run(_run()) + assert conn.executed == [] + + +def test_data_array_observation_uses_shared_set_role_helper(): + src = inspect.getsource(dao.data_array_observation) + assert 'query.format(username=current_user["username"])' not in src + assert "await set_role(conn, current_user)" in src diff --git a/api/tests/test_rls_policy_creation.py b/api/tests/test_rls_policy_creation.py new file mode 100644 index 00000000..fe7dcd82 --- /dev/null +++ b/api/tests/test_rls_policy_creation.py @@ -0,0 +1,206 @@ +# Copyright 2025 SUPSI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for automatic RLS policy creation on user provisioning (POST /Users). + +Verifies that create_user() calls the correct sensorthings.*_policy function +for each application role (viewer, editor, obs_manager, sensor), that the +administrator role is intentionally excluded from policy creation, and that +the policy name follows the ``{username}_default`` convention. + +Relates to Issue #28: two-step user provisioning eliminated. 
+""" + +import asyncio +import os +import sys +from pathlib import Path + +import pytest + +API_DIR = str(Path(__file__).resolve().parents[1]) +if API_DIR not in sys.path: + sys.path.insert(0, API_DIR) + +os.environ.setdefault("ISTSOS_ADMIN", "admin") +os.environ.setdefault("ISTSOS_ADMIN_PASSWORD", "secret") +os.environ.setdefault("POSTGRES_HOST", "localhost") +os.environ.setdefault("POSTGRES_PORT", "5432") +os.environ.setdefault("POSTGRES_DB", "istsos") +os.environ.setdefault("POSTGRES_USER", "admin") +os.environ.setdefault("SECRET_KEY", "test_secret_key_1234567890") +os.environ.setdefault("ALGORITHM", "HS256") + + +# --------------------------------------------------------------------------- +# _POLICY_FN_MAP — replicated here to test the mapping logic independently +# --------------------------------------------------------------------------- + +_POLICY_FN_MAP = { + "viewer": "sensorthings.viewer_policy", + "editor": "sensorthings.editor_policy", + "obs_manager": "sensorthings.obs_manager_policy", + "sensor": "sensorthings.sensor_policy", +} + + +# --------------------------------------------------------------------------- +# Helpers: minimal asyncpg-alike stubs +# --------------------------------------------------------------------------- + +class _Tx: + """Minimal async context manager stub for asyncpg transactions.""" + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class _FakeConnection: + """Minimal asyncpg connection stub that records execute() calls.""" + + def __init__(self, fetchrow_result): + self._fetchrow_result = fetchrow_result + self.executed_queries: list[tuple] = [] + + def transaction(self): + return _Tx() + + async def fetchrow(self, query, *args): + return self._fetchrow_result + + async def execute(self, query, *args): + self.executed_queries.append((query, args)) + + +# --------------------------------------------------------------------------- +# Unit tests: _POLICY_FN_MAP 
mapping +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "app_role, expected_fn", + [ + ("viewer", "sensorthings.viewer_policy"), + ("editor", "sensorthings.editor_policy"), + ("obs_manager", "sensorthings.obs_manager_policy"), + ("sensor", "sensorthings.sensor_policy"), + ], +) +def test_policy_fn_map_correct_function_per_role(app_role, expected_fn): + """Each application role must map to the correct policy function.""" + assert _POLICY_FN_MAP.get(app_role) == expected_fn + + +def test_administrator_not_in_policy_fn_map(): + """Administrator role must NOT be in the policy map (admins bypass RLS).""" + assert _POLICY_FN_MAP.get("administrator") is None + + +# --------------------------------------------------------------------------- +# Integration-level stubs: simulate the policy-call block in create_user() +# --------------------------------------------------------------------------- + +async def _simulate_policy_call(app_role: str, username: str, conn: _FakeConnection): + """ + Reproduce only the auto-policy block from create_user() so we can test + it in isolation without spinning up a full FastAPI app or database. + """ + policy_fn = _POLICY_FN_MAP.get(app_role) + if policy_fn: + policyname = f"{username}_default" + await conn.execute( + f"SELECT {policy_fn}($1, $2);", + [username], + policyname, + ) + + +@pytest.mark.parametrize( + "app_role, expected_fn", + [ + ("viewer", "sensorthings.viewer_policy"), + ("editor", "sensorthings.editor_policy"), + ("obs_manager", "sensorthings.obs_manager_policy"), + ("sensor", "sensorthings.sensor_policy"), + ], +) +def test_policy_call_issued_for_each_role(app_role, expected_fn): + """ + For each non-admin application role, exactly one execute() call must be + made to the correct policy function with the correct arguments. 
+ """ + conn = _FakeConnection(fetchrow_result=None) + username = "testuser" + + asyncio.run(_simulate_policy_call(app_role, username, conn)) + + assert len(conn.executed_queries) == 1, ( + f"Expected 1 execute() call for role '{app_role}', " + f"got {len(conn.executed_queries)}" + ) + query, args = conn.executed_queries[0] + assert f"SELECT {expected_fn}($1, $2);" == query, ( + f"Wrong policy function called for role '{app_role}'" + ) + assert args[0] == [username], "First argument must be a list containing the username" + assert args[1] == f"{username}_default", ( + "Policy name must follow the '{username}_default' convention" + ) + + +def test_administrator_role_does_not_call_policy(): + """ + POST /Users with role=administrator must NOT call any policy function, + because administrators bypass RLS by role privilege, not by policy. + """ + conn = _FakeConnection(fetchrow_result=None) + + asyncio.run(_simulate_policy_call("administrator", "adminuser", conn)) + + assert conn.executed_queries == [], ( + "No policy function should be called for administrator role" + ) + + +def test_policy_name_convention(): + """Policy name must be '{username}_default' for any non-admin role.""" + conn = _FakeConnection(fetchrow_result=None) + username = "alice" + + asyncio.run(_simulate_policy_call("viewer", username, conn)) + + _, args = conn.executed_queries[0] + assert args[1] == "alice_default", ( + "Policy name convention must be '{username}_default'" + ) + + +def test_policy_call_uses_list_for_users_arg(): + """ + The policy function expects `users_` as a text[] parameter. + The first positional argument must therefore be a list, not a plain string. + """ + conn = _FakeConnection(fetchrow_result=None) + username = "bob" + + asyncio.run(_simulate_policy_call("editor", username, conn)) + + _, args = conn.executed_queries[0] + assert isinstance(args[0], list), ( + "The users_ argument must be passed as a Python list (maps to text[])" + ) + assert args[0] == [username]