Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 47 additions & 10 deletions auth/validator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from threading import Lock

import httpx
import jwt
Expand All @@ -13,7 +14,9 @@

logger = logging.getLogger("uvicorn.error")

KEY_CACHE = TTLCache(maxsize=10, ttl=30 * 60)
KEY_CACHE_TIMEOUT = 6 * 60 * 60 # 6 hours
KEY_CACHE = TTLCache(maxsize=10, ttl=KEY_CACHE_TIMEOUT)
KEY_CACHE_LOCK = Lock()


def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload:
Expand All @@ -36,6 +39,7 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload:
issuers.append(settings.auth0_issuer)

payload = None
last_issuer_error = None
for issuer in issuers:
try:
payload = jwt.decode(
Expand All @@ -47,14 +51,16 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload:
)
break
except InvalidIssuerError as e:
logger.warning(f"JWT rejected due to invalid issuer: {e}")
last_issuer_error = e
continue
except InvalidTokenError as e:
logger.warning(f"JWT rejected during decode: {e}")
raise HTTPException(status_code=401, detail="Not authorized")

if payload is None:
logger.warning("JWT rejected: issuer validation failed for all configured issuers")
logger.warning(
f"JWT rejected: issuer validation failed for all configured issuers: {last_issuer_error}"
)
raise HTTPException(status_code=401, detail="Not authorized")

roles_claim = "https://biocommons.org.au/roles"
Expand All @@ -67,14 +73,44 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload:


def _fetch_rsa_keys(auth0_domain: str) -> dict:
    """
    Return the JWKS document for an Auth0 domain, served from KEY_CACHE
    when possible and otherwise refreshed from Auth0.

    On a cache miss the OIDC discovery document is fetched first and the
    JWKS is then downloaded from its ``jwks_uri``. The refresh runs while
    holding KEY_CACHE_LOCK so that callers arriving while the cache is
    empty trigger a single lookup instead of one each.

    :param auth0_domain: Host name of the Auth0 tenant (no scheme).
    :return: The parsed JWKS JSON document.
    :raises InvalidTokenError: If the metadata or JWKS request fails, a
        response body is not valid JSON, or the metadata does not provide
        ``jwks_uri``. Nothing is cached in that case.
    """
    cache_key = f"jwks_{auth0_domain}"

    # Single subscript instead of `in` followed by `[]`: a TTL entry can
    # expire between the two operations, which would leak a KeyError.
    try:
        return KEY_CACHE[cache_key]
    except KeyError:
        pass

    # Lock so we don't do the lookup multiple times
    # if multiple requests come in while cache is expired
    with KEY_CACHE_LOCK:
        # Check again: another request may have refreshed while
        # this was waiting
        try:
            return KEY_CACHE[cache_key]
        except KeyError:
            pass

        metadata_url = f"https://{auth0_domain}/.well-known/openid-configuration"
        try:
            metadata_response = httpx.get(metadata_url)
            metadata_response.raise_for_status()
            metadata = metadata_response.json()

            jwks_url = metadata["jwks_uri"]
            response = httpx.get(jwks_url)
            response.raise_for_status()
            keys = response.json()
        except (KeyError, TypeError) as exc:
            # TypeError covers a metadata document that is valid JSON but
            # not an object (e.g. a list), which cannot be indexed by name.
            logger.error(f"OIDC metadata from {metadata_url} did not include jwks_uri")
            raise InvalidTokenError("Failed to fetch JWKS") from exc
        except (httpx.HTTPError, ValueError) as exc:
            logger.error(
                f"Failed to fetch OIDC metadata or JWKS for domain {auth0_domain}: {exc}"
            )
            # Do not cache on error
            raise InvalidTokenError("Failed to fetch JWKS") from exc

        KEY_CACHE[cache_key] = keys
        return keys


def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True):
Expand All @@ -88,9 +124,10 @@ def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True):
if key.get("kid") == key_id:
return RSAAlgorithm.from_jwk(json.dumps(key))

# Retry without cache on failure
# Retry without cache on failure (but only once, to prevent infinite retry)
if retry_on_failure:
KEY_CACHE.clear()
with KEY_CACHE_LOCK:
KEY_CACHE.clear()
return get_rsa_key(token, settings, retry_on_failure=False)

return None
Expand Down
109 changes: 98 additions & 11 deletions tests/auth/test_auth_validator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import threading
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
Expand All @@ -18,12 +19,18 @@
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from fastapi.testclient import TestClient
from httpx import Response
from httpx import Request, Response
from jwt.algorithms import RSAAlgorithm

from auth import auth0_security, get_auth0_token
from auth.user_permissions import user_is_general_admin
from auth.validator import get_rsa_key, verify_action_token, verify_jwt
from auth.validator import (
KEY_CACHE,
_fetch_rsa_keys,
get_rsa_key,
verify_action_token,
verify_jwt,
)
from config import Settings
from db.models import BiocommonsUser
from tests.datagen import AccessTokenPayloadFactory, SessionUserFactory
Expand Down Expand Up @@ -125,22 +132,31 @@ def create_access_token(
def test_get_rsa_key_returns_key(mock_settings: Settings):
    """
    get_rsa_key resolves the JWKS URL through the OIDC metadata document
    and returns a key when the token's ``kid`` is present in the JWKS.
    """
    token = jwt.encode({"some": "payload"}, TEST_HS256_SECRET, algorithm="HS256")
    unverified_header = {"kid": "testkey"}
    metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration"
    jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json"

    # Real Response objects (with a Request attached) so that
    # raise_for_status() behaves as it does in production.
    metadata_response = Response(
        200,
        json={"jwks_uri": jwks_url},
        request=Request("GET", metadata_url),
    )
    jwks_response = Response(200, json={
        "keys": [{
            "kid": "testkey",
            "kty": "RSA",
            "alg": "RS256",
            "n": "sXchfZm9UOCNHQ",  # base64url-encoded dummy values
            "e": "AQAB"
        }]
    }, request=Request("GET", jwks_url))

    # The JWKS cache is module-level state shared by every test: start from
    # an empty cache so both HTTP calls really happen, and empty it again
    # afterwards so the dummy "testkey" entry cannot leak into other tests.
    KEY_CACHE.clear()
    try:
        with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header), \
                patch("auth.validator.httpx.get") as mock_get:
            mock_get.side_effect = [metadata_response, jwks_response]

            key = get_rsa_key(token, settings=mock_settings)
    finally:
        KEY_CACHE.clear()

    assert key is not None
    assert mock_get.call_args_list[0].args == (metadata_url,)
    assert mock_get.call_args_list[1].args == (jwks_url,)


def generate_dummy_rsa_key(key_id: str) -> dict:
Expand Down Expand Up @@ -176,8 +192,15 @@ def test_get_rsa_key_retry_on_failure(mock_settings: Settings):
"keys": [other_key, missing_key]
}

metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration"
jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json"
route = respx.get(jwks_url).mock(
metadata_route = respx.get(metadata_url).mock(
side_effect=[
Response(200, json={"jwks_uri": jwks_url}),
Response(200, json={"jwks_uri": jwks_url}),
]
)
jwks_route = respx.get(jwks_url).mock(
side_effect=[
Response(200, json=cached_jwks),
Response(200, json=fresh_jwks)
Expand All @@ -192,8 +215,9 @@ def test_get_rsa_key_retry_on_failure(mock_settings: Settings):
key = get_rsa_key(token, settings=mock_settings)
# Verify the key was found after retry
assert key is not None
# Verify that the endpoint was called twice (cached + fresh)
assert route.call_count == 2
# Verify that metadata and key lookups were retried after cache clear.
assert metadata_route.call_count == 2
assert jwks_route.call_count == 2


@respx.mock
Expand All @@ -209,8 +233,12 @@ def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Se
"keys": [found_key]
}

metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration"
jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json"
route = respx.get(jwks_url).mock(return_value=Response(200, json=jwks_response))
metadata_route = respx.get(metadata_url).mock(
return_value=Response(200, json={"jwks_uri": jwks_url})
)
jwks_route = respx.get(jwks_url).mock(return_value=Response(200, json=jwks_response))

# Clear the cache before the test to ensure clean state
from auth.validator import KEY_CACHE
Expand All @@ -223,8 +251,9 @@ def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Se
# Verify key was found
assert key is not None

# Verify endpoint was called only once (no retry needed)
assert route.call_count == 1
# Verify metadata and JWKS were each fetched once.
assert metadata_route.call_count == 1
assert jwks_route.call_count == 1


def test_auth0_security_passes_bearer_token_to_route():
Expand Down Expand Up @@ -548,3 +577,61 @@ def test_verify_action_token_missing_exp(mock_settings: Settings):
verify_action_token(token, mock_settings)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "invalid session_token"


@respx.mock
def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: Settings):
    """
    Demonstrate that concurrent requests only trigger one JWKS refresh.

    Five workers call _fetch_rsa_keys at the same moment against an empty
    cache. The JWKS response is held back until the first request is in
    flight, so the remaining workers are forced to queue behind the refresh.
    Every blocking wait carries a timeout so that a regression fails the
    test instead of hanging the whole run.
    """
    wait_timeout = 10
    worker_count = 5

    KEY_CACHE.clear()

    metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration"
    jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json"
    jwks_response = {"keys": [generate_dummy_rsa_key("test-key")]}

    start_gate = threading.Barrier(worker_count)
    first_request_started = threading.Event()
    release_refresh = threading.Event()
    call_count = 0
    call_count_lock = threading.Lock()

    def slow_jwks_response(*args, **kwargs):
        nonlocal call_count
        with call_count_lock:
            call_count += 1
        first_request_started.set()
        # Bounded wait: never block forever if the main thread bails out.
        release_refresh.wait(timeout=wait_timeout)
        return Response(200, json=jwks_response)

    metadata_route = respx.get(metadata_url).mock(
        return_value=Response(200, json={"jwks_uri": jwks_url})
    )
    jwks_route = respx.get(jwks_url).mock(side_effect=slow_jwks_response)

    results = []
    results_lock = threading.Lock()

    def worker():
        start_gate.wait(timeout=wait_timeout)
        result = _fetch_rsa_keys(mock_settings.auth0_domain)
        with results_lock:
            results.append(result)

    # Daemon threads so a stuck worker cannot keep the interpreter alive.
    threads = [
        threading.Thread(target=worker, daemon=True) for _ in range(worker_count)
    ]
    for thread in threads:
        thread.start()

    try:
        started = first_request_started.wait(timeout=wait_timeout)
    finally:
        # Always release the held response, even if the wait above failed.
        release_refresh.set()

    for thread in threads:
        thread.join(timeout=wait_timeout)

    assert started, "JWKS endpoint was never requested"
    assert not any(thread.is_alive() for thread in threads), "worker thread did not finish"

    # Check all calls went through
    assert len(results) == worker_count
    assert all(result == jwks_response for result in results)
    # Check the OIDC metadata and JWKS endpoints were each only called once.
    assert metadata_route.call_count == 1
    assert jwks_route.call_count == 1
    assert call_count == 1
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading