Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions auth/user_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from schemas.user import SessionUser


def get_session_user(
async def get_session_user(
auth0_token: str = Depends(get_auth0_token), settings: Settings = Depends(get_settings)
) -> SessionUser:
"""
Get the current user's session data (access token).
"""
access_token = verify_jwt(auth0_token, settings=settings)
access_token = await verify_jwt(auth0_token, settings=settings)
return SessionUser(access_token=access_token)


Expand Down
50 changes: 32 additions & 18 deletions auth/validator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
from threading import Lock
import weakref

import httpx
import jwt
Expand All @@ -16,12 +17,24 @@

KEY_CACHE_TIMEOUT = 6 * 60 * 60 # 6 hours
KEY_CACHE = TTLCache(maxsize=10, ttl=KEY_CACHE_TIMEOUT)
KEY_CACHE_LOCK = Lock()
# Lock the key cache per-loop
_KEY_CACHE_LOCKS: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Lock] = (
weakref.WeakKeyDictionary()
)


def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload:
def get_key_cache_lock() -> asyncio.Lock:
loop = asyncio.get_running_loop()
lock = _KEY_CACHE_LOCKS.get(loop)
if lock is None:
lock = asyncio.Lock()
_KEY_CACHE_LOCKS[loop] = lock
return lock

Comment thread
marius-mather marked this conversation as resolved.

async def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload:
try:
rsa_key = get_rsa_key(token, settings=settings)
rsa_key = await get_rsa_key(token, settings=settings)
except InvalidTokenError as e:
logger.warning(f"JWT rejected during RSA key lookup: {e}")
raise HTTPException(status_code=401, detail="Not authorized")
Expand Down Expand Up @@ -72,7 +85,7 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload:
return AccessTokenPayload(**payload)


def _fetch_rsa_keys(auth0_domain: str) -> dict:
async def _fetch_rsa_keys(auth0_domain: str) -> dict:
"""
Try to get cached keys if possible, otherwise
refresh from Auth0
Expand All @@ -83,7 +96,7 @@ def _fetch_rsa_keys(auth0_domain: str) -> dict:

# Lock so we don't do the lookup multiple times
# if multiple requests come in while cache is expired
with KEY_CACHE_LOCK:
async with get_key_cache_lock():
# Check again: another request may have refreshed while
# this was waiting
cached = KEY_CACHE.get(cache_key, None)
Expand All @@ -92,14 +105,15 @@ def _fetch_rsa_keys(auth0_domain: str) -> dict:

try:
metadata_url = f"https://{auth0_domain}/.well-known/openid-configuration"
metadata_response = httpx.get(metadata_url)
metadata_response.raise_for_status()
metadata = metadata_response.json()

jwks_url = metadata["jwks_uri"]
response = httpx.get(jwks_url)
response.raise_for_status()
keys = response.json()
async with httpx.AsyncClient() as client:
metadata_response = await client.get(metadata_url)
metadata_response.raise_for_status()
metadata = metadata_response.json()

jwks_url = metadata["jwks_uri"]
response = await client.get(jwks_url)
response.raise_for_status()
keys = response.json()
except KeyError as exc:
logger.error(f"OIDC metadata from {metadata_url} did not include jwks_uri")
raise InvalidTokenError("Failed to fetch JWKS") from exc
Expand All @@ -113,8 +127,8 @@ def _fetch_rsa_keys(auth0_domain: str) -> dict:
return keys


def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True):
jwks = _fetch_rsa_keys(settings.auth0_domain)
async def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True):
jwks = await _fetch_rsa_keys(settings.auth0_domain)
unverified_header = jwt.get_unverified_header(token)
key_id = unverified_header.get("kid")
if not key_id:
Expand All @@ -126,9 +140,9 @@ def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True):

# Retry without cache on failure (but only once, to prevent infinite retry)
if retry_on_failure:
with KEY_CACHE_LOCK:
async with get_key_cache_lock():
KEY_CACHE.clear()
return get_rsa_key(token, settings, retry_on_failure=False)
return await get_rsa_key(token, settings, retry_on_failure=False)

return None

Expand Down
2 changes: 1 addition & 1 deletion db/st_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def handle_auth_callback(self, request: Request):
access_token = token.get("access_token")
if not access_token:
raise HTTPException(status_code=401, detail="Could not get access token.")
payload = verify_jwt(access_token, settings)
payload = await verify_jwt(access_token, settings)
if not payload:
raise HTTPException(status_code=401, detail="Could not verify JWT.")
if not payload.has_admin_role(settings):
Expand Down
91 changes: 47 additions & 44 deletions tests/auth/test_auth_validator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import json
import threading
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional
from unittest.mock import patch
from unittest.mock import AsyncMock, patch

import jwt
import pytest
Expand Down Expand Up @@ -129,14 +129,15 @@ def create_access_token(
)


def test_get_rsa_key_returns_key(mock_settings: Settings):
@pytest.mark.asyncio
async def test_get_rsa_key_returns_key(mock_settings: Settings):
token = jwt.encode({"some": "payload"}, TEST_HS256_SECRET, algorithm="HS256")
unverified_header = {"kid": "testkey"}
metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration"
jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json"

Comment thread
marius-mather marked this conversation as resolved.
with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header), \
patch("auth.validator.httpx.get") as mock_get:
patch("auth.validator.httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
metadata_response = Response(
200,
json={"jwks_uri": jwks_url},
Expand All @@ -153,10 +154,10 @@ def test_get_rsa_key_returns_key(mock_settings: Settings):
}, request=Request("GET", jwks_url))
mock_get.side_effect = [metadata_response, jwks_response]

key = get_rsa_key(token, settings=mock_settings)
key = await get_rsa_key(token, settings=mock_settings)
assert key is not None
assert mock_get.call_args_list[0].args == (metadata_url,)
assert mock_get.call_args_list[1].args == (jwks_url,)
assert mock_get.await_args_list[0].args == (metadata_url,)
assert mock_get.await_args_list[1].args == (jwks_url,)


def generate_dummy_rsa_key(key_id: str) -> dict:
Expand All @@ -173,8 +174,9 @@ def generate_dummy_rsa_key(key_id: str) -> dict:
return jwk_dict


@pytest.mark.asyncio
@respx.mock
def test_get_rsa_key_retry_on_failure(mock_settings: Settings):
async def test_get_rsa_key_retry_on_failure(mock_settings: Settings):
"""Test that get_rsa_key retries after clearing cache when key is not found."""
token = jwt.encode({"some": "payload"}, TEST_HS256_SECRET, algorithm="HS256")
unverified_header = {"kid": "missing_key"}
Expand Down Expand Up @@ -212,16 +214,17 @@ def test_get_rsa_key_retry_on_failure(mock_settings: Settings):
KEY_CACHE.clear()
with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header):
# Call get_rsa_key
key = get_rsa_key(token, settings=mock_settings)
key = await get_rsa_key(token, settings=mock_settings)
# Verify the key was found after retry
assert key is not None
# Verify that metadata and key lookups were retried after cache clear.
assert metadata_route.call_count == 2
assert jwks_route.call_count == 2


@pytest.mark.asyncio
@respx.mock
def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Settings):
async def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Settings):
"""Test that get_rsa_key doesn't retry when key is found on first attempt."""
token = jwt.encode({"some": "payload"}, TEST_HS256_SECRET, algorithm="HS256")
unverified_header = {"kid": "found_key"}
Expand All @@ -246,7 +249,7 @@ def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Se

with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header):
# Call get_rsa_key
key = get_rsa_key(token, settings=mock_settings)
key = await get_rsa_key(token, settings=mock_settings)

# Verify key was found
assert key is not None
Expand Down Expand Up @@ -323,7 +326,8 @@ def protected_route(token: str = Depends(get_auth0_token)):
assert "detail" in body


def test_verify_jwt(mock_settings: Settings, mocker):
@pytest.mark.asyncio
async def test_verify_jwt(mock_settings: Settings, mocker):
"""
Test we can verify a JWT based on issuer and audience.
"""
Expand All @@ -334,12 +338,13 @@ def test_verify_jwt(mock_settings: Settings, mocker):
iss=f"https://{mock_settings.auth0_domain}/",
aud=f"https://{mock_settings.auth0_domain}/api/",
)
mocker.patch("auth.validator.get_rsa_key", return_value=token.public_key)
decoded = verify_jwt(token.access_token_str, settings=mock_settings)
mocker.patch("auth.validator.get_rsa_key", new=AsyncMock(return_value=token.public_key))
decoded = await verify_jwt(token.access_token_str, settings=mock_settings)
assert decoded.email == "user@example.com"


def test_verify_jwt_invalid_issuer(mock_settings: Settings, mocker):
@pytest.mark.asyncio
async def test_verify_jwt_invalid_issuer(mock_settings: Settings, mocker):
"""
Test invalid JWT issuer returns unauthorized.
"""
Expand All @@ -349,16 +354,17 @@ def test_verify_jwt_invalid_issuer(mock_settings: Settings, mocker):
iss="https://other.example.com/",
aud=f"https://{mock_settings.auth0_domain}/api/",
)
mocker.patch("auth.validator.get_rsa_key", return_value=token.public_key)
mocker.patch("auth.validator.get_rsa_key", new=AsyncMock(return_value=token.public_key))

with pytest.raises(HTTPException) as excinfo:
verify_jwt(token.access_token_str, settings=mock_settings)
await verify_jwt(token.access_token_str, settings=mock_settings)

assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Not authorized"


def test_verify_jwt_custom_domain_issuer(mock_settings: Settings, mocker):
@pytest.mark.asyncio
async def test_verify_jwt_custom_domain_issuer(mock_settings: Settings, mocker):
"""
Check that our verify code also works with the auth0_issuer setting
"""
Expand All @@ -369,12 +375,13 @@ def test_verify_jwt_custom_domain_issuer(mock_settings: Settings, mocker):
iss=mock_settings.auth0_issuer,
aud=mock_settings.auth0_audience,
)
mocker.patch("auth.validator.get_rsa_key", return_value=token.public_key)
decoded = verify_jwt(token.access_token_str, settings=mock_settings)
mocker.patch("auth.validator.get_rsa_key", new=AsyncMock(return_value=token.public_key))
decoded = await verify_jwt(token.access_token_str, settings=mock_settings)
assert decoded.email == "user@example.com"


def test_verify_jwt_missing_kid_returns_unauthorized(mock_settings: Settings, mocker):
@pytest.mark.asyncio
async def test_verify_jwt_missing_kid_returns_unauthorized(mock_settings: Settings, mocker):
"""
Test a token with no kid header returns 401 instead of crashing.
"""
Expand All @@ -387,7 +394,7 @@ def test_verify_jwt_missing_kid_returns_unauthorized(mock_settings: Settings, mo
mocker.patch("auth.validator.jwt.get_unverified_header", return_value={"alg": "RS256"})

with pytest.raises(HTTPException) as excinfo:
verify_jwt(token.access_token_str, settings=mock_settings)
await verify_jwt(token.access_token_str, settings=mock_settings)

assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Not authorized"
Expand Down Expand Up @@ -579,8 +586,9 @@ def test_verify_action_token_missing_exp(mock_settings: Settings):
assert excinfo.value.detail == "invalid session_token"


@pytest.mark.asyncio
@respx.mock
def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: Settings):
async def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: Settings):
"""
Demonstrate that concurrent requests only trigger one JWKS refresh.
"""
Expand All @@ -590,43 +598,38 @@ def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings:
jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json"
jwks_response = {"keys": [generate_dummy_rsa_key("test-key")]}

start_gate = threading.Barrier(5)
first_request_started = threading.Event()
release_refresh = threading.Event()
start_gate = asyncio.Event()
first_request_started = asyncio.Event()
release_refresh = asyncio.Event()
call_count = 0
call_count_lock = threading.Lock()
call_count_lock = asyncio.Lock()

def slow_jwks_response(*args, **kwargs):
async def slow_jwks_response(*args, **kwargs):
nonlocal call_count
with call_count_lock:
async with call_count_lock:
call_count += 1
first_request_started.set()
release_refresh.wait()
await asyncio.wait_for(release_refresh.wait(), timeout=10)
return Response(200, json=jwks_response)

metadata_route = respx.get(metadata_url).mock(
return_value=Response(200, json={"jwks_uri": jwks_url})
)
jwks_route = respx.get(jwks_url).mock(side_effect=slow_jwks_response)

results = []
results_lock = threading.Lock()
results: list[dict] = []

def worker():
start_gate.wait(timeout=10)
result = _fetch_rsa_keys(mock_settings.auth0_domain)
with results_lock:
results.append(result)
async def worker():
await asyncio.wait_for(start_gate.wait(), timeout=10)
result = await _fetch_rsa_keys(mock_settings.auth0_domain)
results.append(result)

threads = [threading.Thread(target=worker) for _ in range(5)]
for thread in threads:
thread.start()
tasks = [asyncio.create_task(worker()) for _ in range(5)]
start_gate.set()

first_request_started.wait()
await asyncio.wait_for(first_request_started.wait(), timeout=10)
release_refresh.set()

for thread in threads:
thread.join(timeout=10)
await asyncio.gather(*tasks)
Comment thread
marius-mather marked this conversation as resolved.

# Check all calls went through
assert len(results) == 5
Expand Down
6 changes: 3 additions & 3 deletions tests/db/test_db_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def test_auth_callback_invalid_jwt_raises(mocker, mock_settings, mock_requ
oauth = Mock()
oauth.create_client.return_value = AsyncMock(authorize_access_token=AsyncMock(return_value={"access_token": "token"}))
mocker.patch("db.st_admin.setup_oauth", return_value=oauth)
mocker.patch("db.st_admin.verify_jwt", return_value=None)
mocker.patch("db.st_admin.verify_jwt", new=AsyncMock(return_value=None))

from db.st_admin import Auth0AuthProvider
provider = Auth0AuthProvider()
Expand All @@ -134,7 +134,7 @@ async def test_auth_callback_missing_admin_role_raises(mocker, mock_settings, mo

payload = Mock()
payload.has_admin_role.return_value = False
mocker.patch("db.st_admin.verify_jwt", return_value=payload)
mocker.patch("db.st_admin.verify_jwt", new=AsyncMock(return_value=payload))

from db.st_admin import Auth0AuthProvider
provider = Auth0AuthProvider()
Expand All @@ -160,7 +160,7 @@ async def test_auth_callback_success_sets_session_and_redirects(mocker, mock_set

payload = Mock()
payload.has_admin_role.return_value = True
mocker.patch("db.st_admin.verify_jwt", return_value=payload)
mocker.patch("db.st_admin.verify_jwt", new=AsyncMock(return_value=payload))

# emulate ?next=/db-admin/
mock_request.query_params = {"next": "/db-admin/"}
Expand Down
Loading
Loading