def _fetch_rsa_keys(auth0_domain: str) -> dict:
    """
    Return the JWKS document for ``auth0_domain``, using the cache if possible.

    On a cache miss the JWKS location is read from the domain's OIDC metadata
    (``/.well-known/openid-configuration``), the key set is downloaded, and the
    result is stored in ``KEY_CACHE`` under ``jwks_<domain>``.

    Raises:
        InvalidTokenError: if the metadata or JWKS cannot be fetched, is not
            valid JSON, or does not have the expected shape. Nothing is
            cached in that case.
    """
    cache_key = f"jwks_{auth0_domain}"

    # Every cache access happens under the lock. cachetools caches are not
    # thread-safe, and a separate "in" check followed by an index lookup can
    # raise KeyError if the entry expires (or is cleared) in between.
    # The lock is deliberately held across the HTTP calls so that concurrent
    # requests wait for one refresh instead of each doing their own lookup.
    with KEY_CACHE_LOCK:
        cached = KEY_CACHE.get(cache_key)
        if cached is not None:
            return cached

        metadata_url = f"https://{auth0_domain}/.well-known/openid-configuration"
        try:
            metadata_response = httpx.get(metadata_url)
            metadata_response.raise_for_status()
            metadata = metadata_response.json()

            # The body may be any JSON value; only an object can carry jwks_uri.
            jwks_url = metadata.get("jwks_uri") if isinstance(metadata, dict) else None
            if not isinstance(jwks_url, str) or not jwks_url:
                logger.error(f"OIDC metadata from {metadata_url} did not include jwks_uri")
                raise InvalidTokenError("Failed to fetch JWKS")

            response = httpx.get(jwks_url)
            response.raise_for_status()
            keys = response.json()

            # Validate before caching: a well-formed-JSON but wrong-shaped
            # response would otherwise be served from the cache for the
            # whole TTL.
            if not isinstance(keys, dict) or not isinstance(keys.get("keys"), list):
                raise ValueError("JWKS response did not contain a 'keys' list")
        except (httpx.HTTPError, ValueError) as exc:
            logger.error(
                f"Failed to fetch OIDC metadata or JWKS for domain {auth0_domain}: {exc}"
            )
            # Do not cache on error
            raise InvalidTokenError("Failed to fetch JWKS") from exc

        KEY_CACHE[cache_key] = keys
        return keys
TestClient -from httpx import Response +from httpx import Request, Response from jwt.algorithms import RSAAlgorithm from auth import auth0_security, get_auth0_token from auth.user_permissions import user_is_general_admin -from auth.validator import get_rsa_key, verify_action_token, verify_jwt +from auth.validator import ( + KEY_CACHE, + _fetch_rsa_keys, + get_rsa_key, + verify_action_token, + verify_jwt, +) from config import Settings from db.models import BiocommonsUser from tests.datagen import AccessTokenPayloadFactory, SessionUserFactory @@ -125,11 +132,17 @@ def create_access_token( def test_get_rsa_key_returns_key(mock_settings: Settings): token = jwt.encode({"some": "payload"}, TEST_HS256_SECRET, algorithm="HS256") unverified_header = {"kid": "testkey"} + metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration" + jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header), \ patch("auth.validator.httpx.get") as mock_get: - - mock_get.return_value.json.return_value = { + metadata_response = Response( + 200, + json={"jwks_uri": jwks_url}, + request=Request("GET", metadata_url), + ) + jwks_response = Response(200, json={ "keys": [{ "kid": "testkey", "kty": "RSA", @@ -137,10 +150,13 @@ def test_get_rsa_key_returns_key(mock_settings: Settings): "n": "sXchfZm9UOCNHQ", # base64url-encoded dummy values "e": "AQAB" }] - } + }, request=Request("GET", jwks_url)) + mock_get.side_effect = [metadata_response, jwks_response] key = get_rsa_key(token, settings=mock_settings) assert key is not None + assert mock_get.call_args_list[0].args == (metadata_url,) + assert mock_get.call_args_list[1].args == (jwks_url,) def generate_dummy_rsa_key(key_id: str) -> dict: @@ -176,8 +192,15 @@ def test_get_rsa_key_retry_on_failure(mock_settings: Settings): "keys": [other_key, missing_key] } + metadata_url = 
f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration" jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" - route = respx.get(jwks_url).mock( + metadata_route = respx.get(metadata_url).mock( + side_effect=[ + Response(200, json={"jwks_uri": jwks_url}), + Response(200, json={"jwks_uri": jwks_url}), + ] + ) + jwks_route = respx.get(jwks_url).mock( side_effect=[ Response(200, json=cached_jwks), Response(200, json=fresh_jwks) @@ -192,8 +215,9 @@ def test_get_rsa_key_retry_on_failure(mock_settings: Settings): key = get_rsa_key(token, settings=mock_settings) # Verify the key was found after retry assert key is not None - # Verify that the endpoint was called twice (cached + fresh) - assert route.call_count == 2 + # Verify that metadata and key lookups were retried after cache clear. + assert metadata_route.call_count == 2 + assert jwks_route.call_count == 2 @respx.mock @@ -209,8 +233,12 @@ def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Se "keys": [found_key] } + metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration" jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" - route = respx.get(jwks_url).mock(return_value=Response(200, json=jwks_response)) + metadata_route = respx.get(metadata_url).mock( + return_value=Response(200, json={"jwks_uri": jwks_url}) + ) + jwks_route = respx.get(jwks_url).mock(return_value=Response(200, json=jwks_response)) # Clear the cache before the test to ensure clean state from auth.validator import KEY_CACHE @@ -223,8 +251,9 @@ def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Se # Verify key was found assert key is not None - # Verify endpoint was called only once (no retry needed) - assert route.call_count == 1 + # Verify metadata and JWKS were each fetched once. 
@respx.mock
def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: Settings):
    """
    Demonstrate that concurrent requests only trigger one JWKS refresh.

    Five threads call ``_fetch_rsa_keys`` at the same time against an empty
    cache while the mocked JWKS endpoint is held open. Every thread must get
    the same key set, and each endpoint must have been hit exactly once.
    """
    KEY_CACHE.clear()

    worker_count = 5
    # Upper bound for every blocking call, so a regression fails the test
    # instead of hanging the test run.
    wait_timeout = 10

    metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration"
    jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json"
    jwks_response = {"keys": [generate_dummy_rsa_key("test-key")]}

    start_gate = threading.Barrier(worker_count)
    first_request_started = threading.Event()
    release_refresh = threading.Event()
    call_count = 0
    call_count_lock = threading.Lock()

    def slow_jwks_response(*args, **kwargs):
        nonlocal call_count
        with call_count_lock:
            call_count += 1
        first_request_started.set()
        # Hold the refresh open until the main thread releases it.
        release_refresh.wait(timeout=wait_timeout)
        return Response(200, json=jwks_response)

    metadata_route = respx.get(metadata_url).mock(
        return_value=Response(200, json={"jwks_uri": jwks_url})
    )
    jwks_route = respx.get(jwks_url).mock(side_effect=slow_jwks_response)

    results = []
    errors = []
    results_lock = threading.Lock()

    def worker():
        try:
            start_gate.wait(timeout=wait_timeout)
            result = _fetch_rsa_keys(mock_settings.auth0_domain)
        except Exception as exc:  # recorded so the main thread can report it
            with results_lock:
                errors.append(exc)
            return
        with results_lock:
            results.append(result)

    # Daemon threads cannot keep the interpreter alive if one gets stuck.
    threads = [
        threading.Thread(target=worker, daemon=True) for _ in range(worker_count)
    ]
    for thread in threads:
        thread.start()

    try:
        assert first_request_started.wait(timeout=wait_timeout), (
            "JWKS endpoint was never requested"
        )
    finally:
        # Always unblock the mocked endpoint, even if the assertion failed.
        release_refresh.set()

    for thread in threads:
        thread.join(timeout=wait_timeout)
    assert not any(thread.is_alive() for thread in threads)

    # Check all calls went through
    assert errors == []
    assert len(results) == worker_count
    assert all(result == jwks_response for result in results)
    # Check the OIDC metadata and JWKS endpoints were each only called once.
    assert metadata_route.call_count == 1
    assert jwks_route.call_count == 1
    assert call_count == 1