Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/app/v1/endpoints/create/data_array_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
set_commit,
update_datastream_last_foi_id,
)
from app.v1.endpoints.functions import set_role

v1 = APIRouter()

Expand Down
23 changes: 23 additions & 0 deletions api/app/v1/endpoints/create/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@
"role": "viewer", # viewer, editor, obs_manager, sensor, custom
}

# Maps the application-layer role to its sensorthings RLS policy function.
# Administrator is intentionally absent — admins bypass RLS by privilege.
_POLICY_FN_MAP = {
"viewer": "sensorthings.viewer_policy",
"editor": "sensorthings.editor_policy",
"obs_manager": "sensorthings.obs_manager_policy",
"sensor": "sensorthings.sensor_policy",
}

@v1.api_route(
"/Users",
Expand Down Expand Up @@ -147,6 +155,9 @@ async def create_user(
if current_user is not None:
await connection.execute("RESET ROLE;")

# Capture app_role before get_db_role_for_rbac() to use
# for RLS policy dispatch below (fixes Issue #28).
app_role = payload["role"]
db_role = get_db_role_for_rbac(payload["role"])

await connection.execute(
Expand All @@ -164,6 +175,18 @@ async def create_user(
)
)

# Auto-create the default RLS policy for the new user.
# Policy functions already exist in the DB (istsos_auth.sql).
# Administrator role bypasses RLS by privilege, not policy.
policy_fn = _POLICY_FN_MAP.get(app_role)
if policy_fn:
policyname = f"{user['username']}_default"
await connection.execute(
f"SELECT {policy_fn}($1, $2);",
[user["username"]],
policyname,
)

return Response(status_code=status.HTTP_201_CREATED)

except UniqueViolationError:
Expand Down
14 changes: 14 additions & 0 deletions api/app/v1/endpoints/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,26 @@


def _validate_role_identifier(username: str) -> str:
"""Validate that *username* is a safe PostgreSQL identifier.

asyncpg does not support $1 placeholders for SET ROLE identifiers —
PostgreSQL's SET ROLE only accepts a literal role name. We therefore
validate the identifier before interpolating it into the query string.

Raises:
ValueError: if the username does not match a plain PG identifier.
"""
if not isinstance(username, str) or not _PG_IDENTIFIER_RE.match(username):
raise ValueError("Invalid role identifier")
return username


async def set_role(connection, current_user):
"""Switch the current session role to *current_user['username']*.

The username is validated against ``_PG_IDENTIFIER_RE`` before use.
Uses ``pg_quote_ident`` to safely quote the identifier for the query.
"""
async with connection.transaction():
username = _validate_role_identifier(current_user["username"])
query = f"SET ROLE {pg_quote_ident(username)};"
Expand Down
86 changes: 86 additions & 0 deletions api/tests/test_rbac_set_role_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import asyncio
import inspect
import os
import sys
from pathlib import Path

import pytest

API_DIR = str(Path(__file__).resolve().parents[1])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)

os.environ.setdefault("ISTSOS_ADMIN", "admin")
os.environ.setdefault("ISTSOS_ADMIN_PASSWORD", "secret")
os.environ.setdefault("POSTGRES_HOST", "localhost")
os.environ.setdefault("POSTGRES_PORT", "5432")
os.environ.setdefault("POSTGRES_DB", "istsos")
os.environ.setdefault("POSTGRES_USER", "admin")
os.environ.setdefault("SECRET_KEY", "test_secret_key_1234567890")
os.environ.setdefault("ALGORITHM", "HS256")

import app.v1.endpoints.create.data_array_observation as dao
from app.v1.endpoints.functions import set_role


class _Tx:
async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
return False


class _Conn:
def __init__(self):
self.executed = []

def transaction(self):
return _Tx()

async def execute(self, query):
self.executed.append(query)


@pytest.mark.parametrize(
"username, expected_query",
[
("alice", 'SET ROLE "alice";'),
("user_1", 'SET ROLE "user_1";'),
],
)
def test_set_role_allows_safe_identifiers(username, expected_query):
conn = _Conn()

async def _run():
await set_role(conn, {"username": username})

asyncio.run(_run())
assert conn.executed == [expected_query]


@pytest.mark.parametrize(
"username",
[
'attacker"; RESET ROLE; --',
"bad-user",
"1starts_with_digit",
"user space",
"",
],
)
def test_set_role_rejects_unsafe_identifiers(username):
conn = _Conn()

async def _run():
await set_role(conn, {"username": username})

with pytest.raises(ValueError, match="Invalid role identifier"):
asyncio.run(_run())
assert conn.executed == []


def test_data_array_observation_uses_shared_set_role_helper():
src = inspect.getsource(dao.data_array_observation)
assert 'query.format(username=current_user["username"])' not in src
assert "await set_role(conn, current_user)" in src
206 changes: 206 additions & 0 deletions api/tests/test_rls_policy_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright 2025 SUPSI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for automatic RLS policy creation on user provisioning (POST /Users).

Verifies that create_user() calls the correct sensorthings.*_policy function
for each application role (viewer, editor, obs_manager, sensor), that the
administrator role is intentionally excluded from policy creation, and that
the policy name follows the ``{username}_default`` convention.

Relates to Issue #28: two-step user provisioning eliminated.
"""

import asyncio
import os
import sys
from pathlib import Path

import pytest

API_DIR = str(Path(__file__).resolve().parents[1])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)

os.environ.setdefault("ISTSOS_ADMIN", "admin")
os.environ.setdefault("ISTSOS_ADMIN_PASSWORD", "secret")
os.environ.setdefault("POSTGRES_HOST", "localhost")
os.environ.setdefault("POSTGRES_PORT", "5432")
os.environ.setdefault("POSTGRES_DB", "istsos")
os.environ.setdefault("POSTGRES_USER", "admin")
os.environ.setdefault("SECRET_KEY", "test_secret_key_1234567890")
os.environ.setdefault("ALGORITHM", "HS256")


# ---------------------------------------------------------------------------
# _POLICY_FN_MAP — replicated here to test the mapping logic independently
# ---------------------------------------------------------------------------

_POLICY_FN_MAP = {
"viewer": "sensorthings.viewer_policy",
"editor": "sensorthings.editor_policy",
"obs_manager": "sensorthings.obs_manager_policy",
"sensor": "sensorthings.sensor_policy",
}


# ---------------------------------------------------------------------------
# Helpers: minimal asyncpg-alike stubs
# ---------------------------------------------------------------------------

class _Tx:
"""Minimal async context manager stub for asyncpg transactions."""

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
return False


class _FakeConnection:
"""Minimal asyncpg connection stub that records execute() calls."""

def __init__(self, fetchrow_result):
self._fetchrow_result = fetchrow_result
self.executed_queries: list[tuple] = []

def transaction(self):
return _Tx()

async def fetchrow(self, query, *args):
return self._fetchrow_result

async def execute(self, query, *args):
self.executed_queries.append((query, args))


# ---------------------------------------------------------------------------
# Unit tests: _POLICY_FN_MAP mapping
# ---------------------------------------------------------------------------

@pytest.mark.parametrize(
"app_role, expected_fn",
[
("viewer", "sensorthings.viewer_policy"),
("editor", "sensorthings.editor_policy"),
("obs_manager", "sensorthings.obs_manager_policy"),
("sensor", "sensorthings.sensor_policy"),
],
)
def test_policy_fn_map_correct_function_per_role(app_role, expected_fn):
"""Each application role must map to the correct policy function."""
assert _POLICY_FN_MAP.get(app_role) == expected_fn


def test_administrator_not_in_policy_fn_map():
"""Administrator role must NOT be in the policy map (admins bypass RLS)."""
assert _POLICY_FN_MAP.get("administrator") is None


# ---------------------------------------------------------------------------
# Integration-level stubs: simulate the policy-call block in create_user()
# ---------------------------------------------------------------------------

async def _simulate_policy_call(app_role: str, username: str, conn: _FakeConnection):
"""
Reproduce only the auto-policy block from create_user() so we can test
it in isolation without spinning up a full FastAPI app or database.
"""
policy_fn = _POLICY_FN_MAP.get(app_role)
if policy_fn:
policyname = f"{username}_default"
await conn.execute(
f"SELECT {policy_fn}($1, $2);",
[username],
policyname,
)


@pytest.mark.parametrize(
"app_role, expected_fn",
[
("viewer", "sensorthings.viewer_policy"),
("editor", "sensorthings.editor_policy"),
("obs_manager", "sensorthings.obs_manager_policy"),
("sensor", "sensorthings.sensor_policy"),
],
)
def test_policy_call_issued_for_each_role(app_role, expected_fn):
"""
For each non-admin application role, exactly one execute() call must be
made to the correct policy function with the correct arguments.
"""
conn = _FakeConnection(fetchrow_result=None)
username = "testuser"

asyncio.run(_simulate_policy_call(app_role, username, conn))

assert len(conn.executed_queries) == 1, (
f"Expected 1 execute() call for role '{app_role}', "
f"got {len(conn.executed_queries)}"
)
query, args = conn.executed_queries[0]
assert f"SELECT {expected_fn}($1, $2);" == query, (
f"Wrong policy function called for role '{app_role}'"
)
assert args[0] == [username], "First argument must be a list containing the username"
assert args[1] == f"{username}_default", (
"Policy name must follow the '{username}_default' convention"
)


def test_administrator_role_does_not_call_policy():
"""
POST /Users with role=administrator must NOT call any policy function,
because administrators bypass RLS by role privilege, not by policy.
"""
conn = _FakeConnection(fetchrow_result=None)

asyncio.run(_simulate_policy_call("administrator", "adminuser", conn))

assert conn.executed_queries == [], (
"No policy function should be called for administrator role"
)


def test_policy_name_convention():
"""Policy name must be '{username}_default' for any non-admin role."""
conn = _FakeConnection(fetchrow_result=None)
username = "alice"

asyncio.run(_simulate_policy_call("viewer", username, conn))

_, args = conn.executed_queries[0]
assert args[1] == "alice_default", (
"Policy name convention must be '{username}_default'"
)


def test_policy_call_uses_list_for_users_arg():
"""
The policy function expects `users_` as a text[] parameter.
The first positional argument must therefore be a list, not a plain string.
"""
conn = _FakeConnection(fetchrow_result=None)
username = "bob"

asyncio.run(_simulate_policy_call("editor", username, conn))

_, args = conn.executed_queries[0]
assert isinstance(args[0], list), (
"The users_ argument must be passed as a Python list (maps to text[])"
)
assert args[0] == [username]