From af642827f7d760823e2c4e064549100b0fda6c39 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Mon, 30 Mar 2026 14:53:48 +1100 Subject: [PATCH 01/10] fix: use a Lock so we can prevent multiple simultaneous fetches of JWKS keys --- auth/validator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auth/validator.py b/auth/validator.py index ae4554a9..a0b55087 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -1,5 +1,6 @@ import json import logging +from threading import Lock import httpx import jwt @@ -14,6 +15,7 @@ logger = logging.getLogger("uvicorn.error") KEY_CACHE = TTLCache(maxsize=10, ttl=30 * 60) +KEY_CACHE_LOCK = Lock() def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: From 9c48ae7671a9d3e1f6a29dc6c410b7ec0ece2d60 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Mon, 30 Mar 2026 14:56:09 +1100 Subject: [PATCH 02/10] fix: increase JWKS cache timeout to 6 hours --- auth/validator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/auth/validator.py b/auth/validator.py index a0b55087..b48511d9 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -14,7 +14,8 @@ logger = logging.getLogger("uvicorn.error") -KEY_CACHE = TTLCache(maxsize=10, ttl=30 * 60) +KEY_CACHE_TIMEOUT = 6 * 60 * 60 # 6 hours +KEY_CACHE = TTLCache(maxsize=10, ttl=KEY_CACHE_TIMEOUT) KEY_CACHE_LOCK = Lock() From 868c16feb7f3edd2af429fb545bb94cbce703901 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Mon, 30 Mar 2026 14:58:45 +1100 Subject: [PATCH 03/10] fix: lock when fetching JWKs to avoid multiple requests --- auth/validator.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/auth/validator.py b/auth/validator.py index b48511d9..0058efd1 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -70,14 +70,28 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: def _fetch_rsa_keys(auth0_domain: str) -> dict: + """ + Try to get cached keys if possible, otherwise + refresh from Auth0 + """ cache_key = f"jwks_{auth0_domain}" if cache_key in KEY_CACHE: return KEY_CACHE[cache_key] - jwks_url = f"https://{auth0_domain}/.well-known/jwks.json" - response = httpx.get(jwks_url) - keys = response.json() - KEY_CACHE[cache_key] = keys - return keys + + # Lock so we don't do the lookup multiple times + # if multiple requests come in while cache is expired + with KEY_CACHE_LOCK: + # Check again: another request may have refreshed while + # this was waiting + cached = KEY_CACHE.get(cache_key, None) + if cached is not None: + return cached + + jwks_url = f"https://{auth0_domain}/.well-known/jwks.json" + response = httpx.get(jwks_url) + keys = response.json() + KEY_CACHE[cache_key] = keys + return keys def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True): From 52431ba8b9820158045ba8be30c79667ad1376a8 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Mon, 30 Mar 2026 14:59:19 +1100 Subject: [PATCH 04/10] fix: lock when clearing key cache --- auth/validator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/auth/validator.py b/auth/validator.py index 0058efd1..9d150478 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -105,9 +105,10 @@ def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True): if key.get("kid") == key_id: return RSAAlgorithm.from_jwk(json.dumps(key)) - # Retry without cache on failure + # Retry without cache on failure (but only once, to prevent infinite retry) if retry_on_failure: - KEY_CACHE.clear() + with KEY_CACHE_LOCK: + KEY_CACHE.clear() return get_rsa_key(token, settings, retry_on_failure=False) return None From a85fb8831e6479d48071b15f14d88955544dfba9 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Mon, 30 Mar 2026 15:00:19 +1100 Subject: [PATCH 05/10] test: add test of locking/caching in jwks lookup --- tests/auth/test_auth_validator.py | 60 ++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/tests/auth/test_auth_validator.py b/tests/auth/test_auth_validator.py index 11373be5..a80ca890 100644 --- a/tests/auth/test_auth_validator.py +++ b/tests/auth/test_auth_validator.py @@ -1,4 +1,6 @@ import json +import threading +import time import uuid from dataclasses import dataclass from datetime import datetime, timedelta @@ -23,7 +25,13 @@ from auth import auth0_security, get_auth0_token from auth.user_permissions import user_is_general_admin -from auth.validator import get_rsa_key, verify_action_token, verify_jwt +from auth.validator import ( + KEY_CACHE, + _fetch_rsa_keys, + get_rsa_key, + verify_action_token, + verify_jwt, +) from config import Settings from db.models import BiocommonsUser from tests.datagen import AccessTokenPayloadFactory, SessionUserFactory @@ -548,3 +556,53 @@ def test_verify_action_token_missing_exp(mock_settings: Settings): verify_action_token(token, mock_settings) assert excinfo.value.status_code == 401 assert excinfo.value.detail == "invalid session_token" + + +@respx.mock +def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: Settings): + """ + Demonstrate that concurrent requests only trigger one JWKS refresh. + """ + KEY_CACHE.clear() + + jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" + jwks_response = {"keys": [generate_dummy_rsa_key("test-key")]} + + start_gate = threading.Barrier(5) + release_refresh = threading.Event() + call_count = 0 + call_count_lock = threading.Lock() + + def slow_response(*args, **kwargs): + nonlocal call_count + with call_count_lock: + call_count += 1 + release_refresh.wait(timeout=2) + return Response(200, json=jwks_response) + + respx.get(jwks_url).mock(side_effect=slow_response) + + results = [] + results_lock = threading.Lock() + + def worker(): + start_gate.wait(timeout=2) + result = _fetch_rsa_keys(mock_settings.auth0_domain) + with results_lock: + results.append(result) + + threads = [threading.Thread(target=worker) for _ in range(5)] + for thread in threads: + thread.start() + + time.sleep(0.1) + release_refresh.set() + + for thread in threads: + thread.join(timeout=2) + + # Check all calls went through + assert len(results) == 5 + assert all(result == jwks_response for result in results) + # Check the Auth0 API was only called once + assert call_count == 1 From e07290f10ca006cdfcea2f4258133f80b98c75b9 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Mon, 30 Mar 2026 15:00:37 +1100 Subject: [PATCH 06/10] chore: update lock file to reflect backend version --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 450f7a17..aaf10ac6 100644 --- a/uv.lock +++ b/uv.lock @@ -8,7 +8,7 @@ resolution-markers = [ [[package]] name = "aai-backend" -version = "1.1.2" +version = "1.1.3" source = { editable = "." } dependencies = [ { name = "alembic" }, From 58433315c5aae3709aad4a49243bd6739015567d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 30 Mar 2026 04:14:03 +0000 Subject: [PATCH 07/10] test: replace timing-sensitive sleep with events --- tests/auth/test_auth_validator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/auth/test_auth_validator.py b/tests/auth/test_auth_validator.py index a80ca890..e4d70942 100644 --- a/tests/auth/test_auth_validator.py +++ b/tests/auth/test_auth_validator.py @@ -1,6 +1,5 @@ import json import threading -import time import uuid from dataclasses import dataclass from datetime import datetime, timedelta @@ -569,6 +568,7 @@ def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: jwks_response = {"keys": [generate_dummy_rsa_key("test-key")]} start_gate = threading.Barrier(5) + first_request_started = threading.Event() release_refresh = threading.Event() call_count = 0 call_count_lock = threading.Lock() @@ -577,7 +577,8 @@ def slow_response(*args, **kwargs): nonlocal call_count with call_count_lock: call_count += 1 - release_refresh.wait(timeout=2) + first_request_started.set() + release_refresh.wait() return Response(200, json=jwks_response) respx.get(jwks_url).mock(side_effect=slow_response) @@ -586,7 +587,7 @@ def slow_response(*args, **kwargs): results_lock = threading.Lock() def worker(): - start_gate.wait(timeout=2) + start_gate.wait(timeout=10) result = _fetch_rsa_keys(mock_settings.auth0_domain) with results_lock: results.append(result) @@ -595,11 +596,11 @@ def worker(): for thread in threads: thread.start() - time.sleep(0.1) + first_request_started.wait() release_refresh.set() for thread in threads: - thread.join(timeout=2) + thread.join(timeout=10) # Check all calls went through assert len(results) == 5 From cb1ce0787024d5752d88639ad8e45edcbf8d6f07 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Mon, 30 Mar 2026 15:31:18 +1100 Subject: [PATCH 08/10] fix: check for errors when fetching JWKs --- auth/validator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/auth/validator.py b/auth/validator.py index 9d150478..9aa4a17a 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -89,7 +89,13 @@ def _fetch_rsa_keys(auth0_domain: str) -> dict: jwks_url = f"https://{auth0_domain}/.well-known/jwks.json" response = httpx.get(jwks_url) - keys = response.json() + try: + response.raise_for_status() + keys = response.json() + except (httpx.HTTPError, ValueError) as exc: + logger.error(f"Failed to fetch JWKS from {jwks_url}: {exc}") + # Do not cache on error + raise InvalidTokenError("Failed to fetch JWKS") from exc KEY_CACHE[cache_key] = keys return keys From 04ba1cdcaa8d88865d93adadab880d7be90e3084 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 31 Mar 2026 10:34:49 +1100 Subject: [PATCH 09/10] fix: use openid-configuration to get jwks_uri --- auth/validator.py | 16 +++++++-- tests/auth/test_auth_validator.py | 54 +++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 16 deletions(-) diff --git a/auth/validator.py b/auth/validator.py index 9aa4a17a..9b465661 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -87,13 +87,23 @@ def _fetch_rsa_keys(auth0_domain: str) -> dict: if cached is not None: return cached - jwks_url = f"https://{auth0_domain}/.well-known/jwks.json" - response = httpx.get(jwks_url) try: + metadata_url = f"https://{auth0_domain}/.well-known/openid-configuration" + metadata_response = httpx.get(metadata_url) + metadata_response.raise_for_status() + metadata = metadata_response.json() + + jwks_url = metadata["jwks_uri"] + response = httpx.get(jwks_url) response.raise_for_status() keys = response.json() + except KeyError as exc: + logger.error(f"OIDC metadata from {metadata_url} did not include jwks_uri") + raise InvalidTokenError("Failed to fetch JWKS") from exc except (httpx.HTTPError, ValueError) as exc: - logger.error(f"Failed to fetch JWKS from {jwks_url}: {exc}") + logger.error( + f"Failed to fetch OIDC metadata or JWKS for domain {auth0_domain}: {exc}" + ) # Do not cache on error raise InvalidTokenError("Failed to fetch JWKS") from exc KEY_CACHE[cache_key] = keys diff --git a/tests/auth/test_auth_validator.py b/tests/auth/test_auth_validator.py index e4d70942..e8d70498 100644 --- a/tests/auth/test_auth_validator.py +++ b/tests/auth/test_auth_validator.py @@ -19,7 +19,7 @@ from fastapi import Depends, FastAPI, HTTPException from fastapi.security import HTTPAuthorizationCredentials from fastapi.testclient import TestClient -from httpx import Response +from httpx import Request, Response from jwt.algorithms import RSAAlgorithm from auth import auth0_security, get_auth0_token @@ -132,11 +132,17 @@ def create_access_token( def test_get_rsa_key_returns_key(mock_settings: Settings): token = jwt.encode({"some": "payload"}, TEST_HS256_SECRET, algorithm="HS256") unverified_header = {"kid": "testkey"} + metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration" + jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header), \ patch("auth.validator.httpx.get") as mock_get: - - mock_get.return_value.json.return_value = { + metadata_response = Response( + 200, + json={"jwks_uri": jwks_url}, + request=Request("GET", metadata_url), + ) + jwks_response = Response(200, json={ "keys": [{ "kid": "testkey", "kty": "RSA", @@ -144,10 +150,13 @@ def test_get_rsa_key_returns_key(mock_settings: Settings): "n": "sXchfZm9UOCNHQ", # base64url-encoded dummy values "e": "AQAB" }] - } + }, request=Request("GET", jwks_url)) + mock_get.side_effect = [metadata_response, jwks_response] key = get_rsa_key(token, settings=mock_settings) assert key is not None + assert mock_get.call_args_list[0].args == (metadata_url,) + assert mock_get.call_args_list[1].args == (jwks_url,) def generate_dummy_rsa_key(key_id: str) -> dict: @@ -183,8 +192,15 @@ def test_get_rsa_key_retry_on_failure(mock_settings: Settings): "keys": [other_key, missing_key] } + metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration" jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" - route = respx.get(jwks_url).mock( + metadata_route = respx.get(metadata_url).mock( + side_effect=[ + Response(200, json={"jwks_uri": jwks_url}), + Response(200, json={"jwks_uri": jwks_url}), + ] + ) + jwks_route = respx.get(jwks_url).mock( side_effect=[ Response(200, json=cached_jwks), Response(200, json=fresh_jwks) @@ -199,8 +215,9 @@ def test_get_rsa_key_retry_on_failure(mock_settings: Settings): key = get_rsa_key(token, settings=mock_settings) # Verify the key was found after retry assert key is not None - # Verify that the endpoint was called twice (cached + fresh) - assert route.call_count == 2 + # Verify that metadata and key lookups were retried after cache clear. + assert metadata_route.call_count == 2 + assert jwks_route.call_count == 2 @respx.mock @@ -216,8 +233,12 @@ def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Se "keys": [found_key] } + metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration" jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" - route = respx.get(jwks_url).mock(return_value=Response(200, json=jwks_response)) + metadata_route = respx.get(metadata_url).mock( + return_value=Response(200, json={"jwks_uri": jwks_url}) + ) + jwks_route = respx.get(jwks_url).mock(return_value=Response(200, json=jwks_response)) # Clear the cache before the test to ensure clean state from auth.validator import KEY_CACHE @@ -230,8 +251,9 @@ def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Se # Verify key was found assert key is not None - # Verify endpoint was called only once (no retry needed) - assert route.call_count == 1 + # Verify metadata and JWKS were each fetched once. + assert metadata_route.call_count == 1 + assert jwks_route.call_count == 1 def test_auth0_security_passes_bearer_token_to_route(): @@ -564,6 +586,7 @@ def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: """ KEY_CACHE.clear() + metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration" jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" jwks_response = {"keys": [generate_dummy_rsa_key("test-key")]} @@ -573,7 +596,7 @@ def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: call_count = 0 call_count_lock = threading.Lock() - def slow_response(*args, **kwargs): + def slow_jwks_response(*args, **kwargs): nonlocal call_count with call_count_lock: call_count += 1 @@ -581,7 +604,10 @@ def slow_response(*args, **kwargs): release_refresh.wait() return Response(200, json=jwks_response) - respx.get(jwks_url).mock(side_effect=slow_response) + metadata_route = respx.get(metadata_url).mock( + return_value=Response(200, json={"jwks_uri": jwks_url}) + ) + jwks_route = respx.get(jwks_url).mock(side_effect=slow_jwks_response) results = [] results_lock = threading.Lock() @@ -605,5 +631,7 @@ def worker(): # Check all calls went through assert len(results) == 5 assert all(result == jwks_response for result in results) - # Check the Auth0 API was only called once + # Check the OIDC metadata and JWKS endpoints were each only called once. + assert metadata_route.call_count == 1 + assert jwks_route.call_count == 1 assert call_count == 1 From b7802757ab4fb451e4296f24c6072965fef96b24 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 31 Mar 2026 10:46:25 +1100 Subject: [PATCH 10/10] fix: don't log on every invalid issuer, only if all are invalid --- auth/validator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/auth/validator.py b/auth/validator.py index 9b465661..f4d9280e 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -39,6 +39,7 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: issuers.append(settings.auth0_issuer) payload = None + last_issuer_error = None for issuer in issuers: try: payload = jwt.decode( @@ -50,14 +51,16 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: ) break except InvalidIssuerError as e: - logger.warning(f"JWT rejected due to invalid issuer: {e}") + last_issuer_error = e continue except InvalidTokenError as e: logger.warning(f"JWT rejected during decode: {e}") raise HTTPException(status_code=401, detail="Not authorized") if payload is None: - logger.warning("JWT rejected: issuer validation failed for all configured issuers") + logger.warning( + f"JWT rejected: issuer validation failed for all configured issuers: {last_issuer_error}" + ) raise HTTPException(status_code=401, detail="Not authorized") roles_claim = "https://biocommons.org.au/roles"