Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions auth/validator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import lru_cache
from typing import Annotated

import httpx
from cachetools import TTLCache
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import jwk, jwt
Expand All @@ -13,6 +13,8 @@

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

KEY_CACHE = TTLCache(maxsize=10, ttl=30 * 60)


def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload:
try:
Expand Down Expand Up @@ -50,21 +52,30 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload:
return AccessTokenPayload(**payload)


@lru_cache(maxsize=100)
def _fetch_rsa_keys(auth0_domain: str) -> dict:
cache_key = f"jwks_{auth0_domain}"
if cache_key in KEY_CACHE:
return KEY_CACHE[cache_key]
jwks_url = f"https://{auth0_domain}/.well-known/jwks.json"
response = httpx.get(jwks_url)
return response.json()
keys = response.json()
KEY_CACHE[cache_key] = keys
return keys


def get_rsa_key(token: str, settings: Settings) -> jwk.RSAKey | None: # type: ignore
def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True) -> jwk.RSAKey | None: # type: ignore
jwks = _fetch_rsa_keys(settings.auth0_domain)
unverified_header = jwt.get_unverified_header(token)

for key in jwks["keys"]:
if key["kid"] == unverified_header["kid"]:
return jwk.construct(key)

# Retry without cache on failure
if retry_on_failure:
KEY_CACHE.clear()
return get_rsa_key(token, settings, retry_on_failure=False)

return None


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"itsdangerous>=2.2.0",
"apscheduler[redis,sqlalchemy]~=3.11",
"loguru>=0.7.3",
"cachetools~=6.2",
]

[project.optional-dependencies]
Expand Down
93 changes: 92 additions & 1 deletion tests/test_auth_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from unittest.mock import patch

import pytest
import respx

# Tools from hazmat should only be used for testing!
from cryptography.hazmat.primitives.asymmetric import rsa
Expand All @@ -13,7 +14,8 @@
RSAPublicKey,
)
from fastapi import HTTPException
from jose import jwt
from httpx import Response
from jose import jwk, jwt
from jose.backends.cryptography_backend import CryptographyRSAKey

from auth.validator import get_rsa_key, verify_jwt
Expand Down Expand Up @@ -130,6 +132,95 @@ def test_get_rsa_key_returns_key(mock_settings: Settings):
assert isinstance(key, CryptographyRSAKey)


def generate_dummy_rsa_key(key_id: str) -> dict:
"""Generate a test RSA key in JWKS format using jose."""
# Generate RSA key pair
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

# Create JWK from the public key
public_jwk = jwk.construct(private_key.public_key(), algorithm="RS256")

# Convert to dict and add kid
jwk_dict = public_jwk.to_dict()
jwk_dict["kid"] = key_id
jwk_dict["use"] = "sig"
jwk_dict["alg"] = "RS256"

return jwk_dict


@respx.mock
def test_get_rsa_key_retry_on_failure(mock_settings: Settings):
"""Test that get_rsa_key retries after clearing cache when key is not found."""
token = jwt.encode({"some": "payload"}, "secret", algorithm="HS256")
unverified_header = {"kid": "missing_key"}

other_key = generate_dummy_rsa_key("other_key")
missing_key = generate_dummy_rsa_key("missing_key")

# Mock the cached response (first call) - key not found
cached_jwks = {
"keys": [other_key]
}

# Mock the fresh response (second call after cache clear) - key found
fresh_jwks = {
"keys": [other_key, missing_key]
}

jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json"
route = respx.get(jwks_url).mock(
side_effect=[
Response(200, json=cached_jwks),
Response(200, json=fresh_jwks)
]
)

# Clear the cache before the test to ensure clean state
from auth.validator import KEY_CACHE
KEY_CACHE.clear()
with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header):
# Call get_rsa_key
key = get_rsa_key(token, settings=mock_settings)
# Verify the key was found after retry
assert key is not None
assert isinstance(key, CryptographyRSAKey)
# Verify that the endpoint was called twice (cached + fresh)
assert route.call_count == 2


@respx.mock
def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Settings):
"""Test that get_rsa_key doesn't retry when key is found on first attempt."""
token = jwt.encode({"some": "payload"}, "secret", algorithm="HS256")
unverified_header = {"kid": "found_key"}

# Generate test key using jose
found_key = generate_dummy_rsa_key("found_key")

jwks_response = {
"keys": [found_key]
}

jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json"
route = respx.get(jwks_url).mock(return_value=Response(200, json=jwks_response))

# Clear the cache before the test to ensure clean state
from auth.validator import KEY_CACHE
KEY_CACHE.clear()

with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header):
# Call get_rsa_key
key = get_rsa_key(token, settings=mock_settings)

# Verify key was found
assert key is not None
assert isinstance(key, CryptographyRSAKey)

# Verify endpoint was called only once (no retry needed)
assert route.call_count == 1


def test_verify_jwt(mock_settings: Settings, mocker):
"""
Test we can verify a JWT based on issuer and audience.
Expand Down
11 changes: 11 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.