diff --git a/auth/validator.py b/auth/validator.py index d3cac112..29a35fcd 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -1,7 +1,7 @@ -from functools import lru_cache from typing import Annotated import httpx +from cachetools import TTLCache from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from jose import jwk, jwt @@ -13,6 +13,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +KEY_CACHE = TTLCache(maxsize=10, ttl=30 * 60) + def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: try: @@ -50,14 +52,18 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: return AccessTokenPayload(**payload) -@lru_cache(maxsize=100) def _fetch_rsa_keys(auth0_domain: str) -> dict: + cache_key = f"jwks_{auth0_domain}" + if cache_key in KEY_CACHE: + return KEY_CACHE[cache_key] jwks_url = f"https://{auth0_domain}/.well-known/jwks.json" response = httpx.get(jwks_url) - return response.json() + keys = response.json() + KEY_CACHE[cache_key] = keys + return keys -def get_rsa_key(token: str, settings: Settings) -> jwk.RSAKey | None: # type: ignore +def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True) -> jwk.RSAKey | None: # type: ignore jwks = _fetch_rsa_keys(settings.auth0_domain) unverified_header = jwt.get_unverified_header(token) @@ -65,6 +71,11 @@ def get_rsa_key(token: str, settings: Settings) -> jwk.RSAKey | None: # type: i if key["kid"] == unverified_header["kid"]: return jwk.construct(key) + # Retry without cache on failure + if retry_on_failure: + KEY_CACHE.clear() + return get_rsa_key(token, settings, retry_on_failure=False) + return None diff --git a/pyproject.toml b/pyproject.toml index 64c03a26..915d60fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "itsdangerous>=2.2.0", "apscheduler[redis,sqlalchemy]~=3.11", "loguru>=0.7.3", + "cachetools~=6.2", ] [project.optional-dependencies] diff --git a/tests/test_auth_validator.py b/tests/test_auth_validator.py index 703dbdcd..89b61878 100644 --- a/tests/test_auth_validator.py +++ b/tests/test_auth_validator.py @@ -5,6 +5,7 @@ from unittest.mock import patch import pytest +import respx # Tools from hazmat should only be used for testing! from cryptography.hazmat.primitives.asymmetric import rsa @@ -13,7 +14,8 @@ RSAPublicKey, ) from fastapi import HTTPException -from jose import jwt +from httpx import Response +from jose import jwk, jwt from jose.backends.cryptography_backend import CryptographyRSAKey from auth.validator import get_rsa_key, verify_jwt @@ -130,6 +132,95 @@ def test_get_rsa_key_returns_key(mock_settings: Settings): assert isinstance(key, CryptographyRSAKey) +def generate_dummy_rsa_key(key_id: str) -> dict: + """Generate a test RSA key in JWKS format using jose.""" + # Generate RSA key pair + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + # Create JWK from the public key + public_jwk = jwk.construct(private_key.public_key(), algorithm="RS256") + + # Convert to dict and add kid + jwk_dict = public_jwk.to_dict() + jwk_dict["kid"] = key_id + jwk_dict["use"] = "sig" + jwk_dict["alg"] = "RS256" + + return jwk_dict + + +@respx.mock +def test_get_rsa_key_retry_on_failure(mock_settings: Settings): + """Test that get_rsa_key retries after clearing cache when key is not found.""" + token = jwt.encode({"some": "payload"}, "secret", algorithm="HS256") + unverified_header = {"kid": "missing_key"} + + other_key = generate_dummy_rsa_key("other_key") + missing_key = generate_dummy_rsa_key("missing_key") + + # Mock the cached response (first call) - key not found + cached_jwks = { + "keys": [other_key] + } + + # Mock the fresh response (second call after cache clear) - key found + fresh_jwks = { + "keys": [other_key, missing_key] + } + + jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" + route = respx.get(jwks_url).mock( + side_effect=[ + Response(200, json=cached_jwks), + Response(200, json=fresh_jwks) + ] + ) + + # Clear the cache before the test to ensure clean state + from auth.validator import KEY_CACHE + KEY_CACHE.clear() + with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header): + # Call get_rsa_key + key = get_rsa_key(token, settings=mock_settings) + # Verify the key was found after retry + assert key is not None + assert isinstance(key, CryptographyRSAKey) + # Verify that the endpoint was called twice (cached + fresh) + assert route.call_count == 2 + + +@respx.mock +def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Settings): + """Test that get_rsa_key doesn't retry when key is found on first attempt.""" + token = jwt.encode({"some": "payload"}, "secret", algorithm="HS256") + unverified_header = {"kid": "found_key"} + + # Generate test key using jose + found_key = generate_dummy_rsa_key("found_key") + + jwks_response = { + "keys": [found_key] + } + + jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" + route = respx.get(jwks_url).mock(return_value=Response(200, json=jwks_response)) + + # Clear the cache before the test to ensure clean state + from auth.validator import KEY_CACHE + KEY_CACHE.clear() + + with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header): + # Call get_rsa_key + key = get_rsa_key(token, settings=mock_settings) + + # Verify key was found + assert key is not None + assert isinstance(key, CryptographyRSAKey) + + # Verify endpoint was called only once (no retry needed) + assert route.call_count == 1 + + def test_verify_jwt(mock_settings: Settings, mocker): """ Test we can verify a JWT based on issuer and audience. diff --git a/uv.lock b/uv.lock index 1a57c393..b2eef5eb 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,7 @@ dependencies = [ { name = "apscheduler", extra = ["redis", "sqlalchemy"] }, { name = "authlib" }, { name = "boto3" }, + { name = "cachetools" }, { name = "fastapi", extra = ["standard"] }, { name = "httpx" }, { name = "itsdangerous" }, @@ -43,6 +44,7 @@ requires-dist = [ { name = "apscheduler", extras = ["redis", "sqlalchemy"], specifier = "~=3.11" }, { name = "authlib", specifier = ">=1.6.1" }, { name = "boto3", specifier = ">=1.34.0" }, + { name = "cachetools", specifier = "~=6.2" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.12" }, { name = "freezegun", marker = "extra == 'dev'", specifier = ">=1.5.2" }, { name = "httpx", specifier = ">=0.28.1" }, @@ -162,6 +164,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/b6/dcd0fd188cc28d772e0df23a31ce50af4d358ef31bfee969dc5a033482a5/botocore-1.39.0-py3-none-any.whl", hash = "sha256:d8e72850d3450aeca355b654efb32c8370bf824c1945a61cad2395dc2688581e", size = 13753356, upload-time = "2025-06-30T19:24:49.416Z" }, ] +[[package]] +name = "cachetools" +version = "6.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/61/e4fad8155db4a04bfb4734c7c8ff0882f078f24294d42798b3568eb63bff/cachetools-6.2.0.tar.gz", hash = "sha256:38b328c0889450f05f5e120f56ab68c8abaf424e1275522b138ffc93253f7e32", size = 30988, upload-time = "2025-08-25T18:57:30.924Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/56/3124f61d37a7a4e7cc96afc5492c78ba0cb551151e530b54669ddd1436ef/cachetools-6.2.0-py3-none-any.whl", hash = "sha256:1c76a8960c0041fcc21097e357f882197c79da0dbff766e7317890a65d7d8ba6", size = 11276, upload-time = "2025-08-25T18:57:29.684Z" }, +] + [[package]] name = "certifi" version = "2025.6.15"