diff --git a/auth/user_permissions.py b/auth/user_permissions.py index 5939b214..8ac6a6e6 100644 --- a/auth/user_permissions.py +++ b/auth/user_permissions.py @@ -13,13 +13,13 @@ from schemas.user import SessionUser -def get_session_user( +async def get_session_user( auth0_token: str = Depends(get_auth0_token), settings: Settings = Depends(get_settings) ) -> SessionUser: """ Get the current user's session data (access token). """ - access_token = verify_jwt(auth0_token, settings=settings) + access_token = await verify_jwt(auth0_token, settings=settings) return SessionUser(access_token=access_token) diff --git a/auth/validator.py b/auth/validator.py index f4d9280e..a0491fe4 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -1,6 +1,7 @@ +import asyncio import json import logging -from threading import Lock +import weakref import httpx import jwt @@ -16,12 +17,24 @@ KEY_CACHE_TIMEOUT = 6 * 60 * 60 # 6 hours KEY_CACHE = TTLCache(maxsize=10, ttl=KEY_CACHE_TIMEOUT) -KEY_CACHE_LOCK = Lock() +# Lock the key cache per-loop +_KEY_CACHE_LOCKS: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Lock] = ( + weakref.WeakKeyDictionary() +) -def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: +def get_key_cache_lock() -> asyncio.Lock: + loop = asyncio.get_running_loop() + lock = _KEY_CACHE_LOCKS.get(loop) + if lock is None: + lock = asyncio.Lock() + _KEY_CACHE_LOCKS[loop] = lock + return lock + + +async def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: try: - rsa_key = get_rsa_key(token, settings=settings) + rsa_key = await get_rsa_key(token, settings=settings) except InvalidTokenError as e: logger.warning(f"JWT rejected during RSA key lookup: {e}") raise HTTPException(status_code=401, detail="Not authorized") @@ -72,7 +85,7 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: return AccessTokenPayload(**payload) -def _fetch_rsa_keys(auth0_domain: str) -> dict: +async def _fetch_rsa_keys(auth0_domain: str) -> dict: """ Try to get cached keys if possible, otherwise refresh from Auth0 @@ -83,7 +96,7 @@ def _fetch_rsa_keys(auth0_domain: str) -> dict: # Lock so we don't do the lookup multiple times # if multiple requests come in while cache is expired - with KEY_CACHE_LOCK: + async with get_key_cache_lock(): # Check again: another request may have refreshed while # this was waiting cached = KEY_CACHE.get(cache_key, None) @@ -92,14 +105,15 @@ def _fetch_rsa_keys(auth0_domain: str) -> dict: try: metadata_url = f"https://{auth0_domain}/.well-known/openid-configuration" - metadata_response = httpx.get(metadata_url) - metadata_response.raise_for_status() - metadata = metadata_response.json() - - jwks_url = metadata["jwks_uri"] - response = httpx.get(jwks_url) - response.raise_for_status() - keys = response.json() + async with httpx.AsyncClient() as client: + metadata_response = await client.get(metadata_url) + metadata_response.raise_for_status() + metadata = metadata_response.json() + + jwks_url = metadata["jwks_uri"] + response = await client.get(jwks_url) + response.raise_for_status() + keys = response.json() except KeyError as exc: logger.error(f"OIDC metadata from {metadata_url} did not include jwks_uri") raise InvalidTokenError("Failed to fetch JWKS") from exc @@ -113,8 +127,8 @@ def _fetch_rsa_keys(auth0_domain: str) -> dict: return keys -def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True): - jwks = _fetch_rsa_keys(settings.auth0_domain) +async def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True): + jwks = await _fetch_rsa_keys(settings.auth0_domain) unverified_header = jwt.get_unverified_header(token) key_id = unverified_header.get("kid") if not key_id: @@ -126,9 +140,9 @@ def get_rsa_key(token: str, settings: Settings, retry_on_failure: bool = True): # Retry without cache on failure (but only once, to prevent infinite retry) if retry_on_failure: - with KEY_CACHE_LOCK: + async with get_key_cache_lock(): KEY_CACHE.clear() - return get_rsa_key(token, settings, retry_on_failure=False) + return await get_rsa_key(token, settings, retry_on_failure=False) return None diff --git a/db/st_admin.py b/db/st_admin.py index 41652dac..07aacb36 100644 --- a/db/st_admin.py +++ b/db/st_admin.py @@ -338,7 +338,7 @@ async def handle_auth_callback(self, request: Request): access_token = token.get("access_token") if not access_token: raise HTTPException(status_code=401, detail="Could not get access token.") - payload = verify_jwt(access_token, settings) + payload = await verify_jwt(access_token, settings) if not payload: raise HTTPException(status_code=401, detail="Could not verify JWT.") if not payload.has_admin_role(settings): diff --git a/tests/auth/test_auth_validator.py b/tests/auth/test_auth_validator.py index e8d70498..302f0bc5 100644 --- a/tests/auth/test_auth_validator.py +++ b/tests/auth/test_auth_validator.py @@ -1,10 +1,10 @@ +import asyncio import json -import threading import uuid from dataclasses import dataclass from datetime import datetime, timedelta from typing import Optional -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import jwt import pytest @@ -129,14 +129,15 @@ def create_access_token( ) -def test_get_rsa_key_returns_key(mock_settings: Settings): +@pytest.mark.asyncio +async def test_get_rsa_key_returns_key(mock_settings: Settings): token = jwt.encode({"some": "payload"}, TEST_HS256_SECRET, algorithm="HS256") unverified_header = {"kid": "testkey"} metadata_url = f"https://{mock_settings.auth0_domain}/.well-known/openid-configuration" jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header), \ - patch("auth.validator.httpx.get") as mock_get: + patch("auth.validator.httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get: metadata_response = Response( 200, json={"jwks_uri": jwks_url}, @@ -153,10 +154,10 @@ def test_get_rsa_key_returns_key(mock_settings: Settings): }, request=Request("GET", jwks_url)) mock_get.side_effect = [metadata_response, jwks_response] - key = get_rsa_key(token, settings=mock_settings) + key = await get_rsa_key(token, settings=mock_settings) assert key is not None - assert mock_get.call_args_list[0].args == (metadata_url,) - assert mock_get.call_args_list[1].args == (jwks_url,) + assert mock_get.await_args_list[0].args == (metadata_url,) + assert mock_get.await_args_list[1].args == (jwks_url,) def generate_dummy_rsa_key(key_id: str) -> dict: @@ -173,8 +174,9 @@ def generate_dummy_rsa_key(key_id: str) -> dict: return jwk_dict +@pytest.mark.asyncio @respx.mock -def test_get_rsa_key_retry_on_failure(mock_settings: Settings): +async def test_get_rsa_key_retry_on_failure(mock_settings: Settings): """Test that get_rsa_key retries after clearing cache when key is not found.""" token = jwt.encode({"some": "payload"}, TEST_HS256_SECRET, algorithm="HS256") unverified_header = {"kid": "missing_key"} @@ -212,7 +214,7 @@ def test_get_rsa_key_retry_on_failure(mock_settings: Settings): KEY_CACHE.clear() with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header): # Call get_rsa_key - key = get_rsa_key(token, settings=mock_settings) + key = await get_rsa_key(token, settings=mock_settings) # Verify the key was found after retry assert key is not None # Verify that metadata and key lookups were retried after cache clear. @@ -220,8 +222,9 @@ def test_get_rsa_key_retry_on_failure(mock_settings: Settings): assert jwks_route.call_count == 2 +@pytest.mark.asyncio @respx.mock -def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Settings): +async def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Settings): """Test that get_rsa_key doesn't retry when key is found on first attempt.""" token = jwt.encode({"some": "payload"}, TEST_HS256_SECRET, algorithm="HS256") unverified_header = {"kid": "found_key"} @@ -246,7 +249,7 @@ def test_get_rsa_key_no_retry_needed_when_key_found_first_time(mock_settings: Se with patch("auth.validator.jwt.get_unverified_header", return_value=unverified_header): # Call get_rsa_key - key = get_rsa_key(token, settings=mock_settings) + key = await get_rsa_key(token, settings=mock_settings) # Verify key was found assert key is not None @@ -323,7 +326,8 @@ def protected_route(token: str = Depends(get_auth0_token)): assert "detail" in body -def test_verify_jwt(mock_settings: Settings, mocker): +@pytest.mark.asyncio +async def test_verify_jwt(mock_settings: Settings, mocker): """ Test we can verify a JWT based on issuer and audience. """ @@ -334,12 +338,13 @@ def test_verify_jwt(mock_settings: Settings, mocker): iss=f"https://{mock_settings.auth0_domain}/", aud=f"https://{mock_settings.auth0_domain}/api/", ) - mocker.patch("auth.validator.get_rsa_key", return_value=token.public_key) - decoded = verify_jwt(token.access_token_str, settings=mock_settings) + mocker.patch("auth.validator.get_rsa_key", new=AsyncMock(return_value=token.public_key)) + decoded = await verify_jwt(token.access_token_str, settings=mock_settings) assert decoded.email == "user@example.com" -def test_verify_jwt_invalid_issuer(mock_settings: Settings, mocker): +@pytest.mark.asyncio +async def test_verify_jwt_invalid_issuer(mock_settings: Settings, mocker): """ Test invalid JWT issuer returns unauthorized. """ @@ -349,16 +354,17 @@ def test_verify_jwt_invalid_issuer(mock_settings: Settings, mocker): iss="https://other.example.com/", aud=f"https://{mock_settings.auth0_domain}/api/", ) - mocker.patch("auth.validator.get_rsa_key", return_value=token.public_key) + mocker.patch("auth.validator.get_rsa_key", new=AsyncMock(return_value=token.public_key)) with pytest.raises(HTTPException) as excinfo: - verify_jwt(token.access_token_str, settings=mock_settings) + await verify_jwt(token.access_token_str, settings=mock_settings) assert excinfo.value.status_code == 401 assert excinfo.value.detail == "Not authorized" -def test_verify_jwt_custom_domain_issuer(mock_settings: Settings, mocker): +@pytest.mark.asyncio +async def test_verify_jwt_custom_domain_issuer(mock_settings: Settings, mocker): """ Check that our verify code also works with the auth0_issuer setting """ @@ -369,12 +375,13 @@ def test_verify_jwt_custom_domain_issuer(mock_settings: Settings, mocker): iss=mock_settings.auth0_issuer, aud=mock_settings.auth0_audience, ) - mocker.patch("auth.validator.get_rsa_key", return_value=token.public_key) - decoded = verify_jwt(token.access_token_str, settings=mock_settings) + mocker.patch("auth.validator.get_rsa_key", new=AsyncMock(return_value=token.public_key)) + decoded = await verify_jwt(token.access_token_str, settings=mock_settings) assert decoded.email == "user@example.com" -def test_verify_jwt_missing_kid_returns_unauthorized(mock_settings: Settings, mocker): +@pytest.mark.asyncio +async def test_verify_jwt_missing_kid_returns_unauthorized(mock_settings: Settings, mocker): """ Test a token with no kid header returns 401 instead of crashing. """ @@ -387,7 +394,7 @@ def test_verify_jwt_missing_kid_returns_unauthorized(mock_settings: Settings, mo mocker.patch("auth.validator.jwt.get_unverified_header", return_value={"alg": "RS256"}) with pytest.raises(HTTPException) as excinfo: - verify_jwt(token.access_token_str, settings=mock_settings) + await verify_jwt(token.access_token_str, settings=mock_settings) assert excinfo.value.status_code == 401 assert excinfo.value.detail == "Not authorized" @@ -579,8 +586,9 @@ def test_verify_action_token_missing_exp(mock_settings: Settings): assert excinfo.value.detail == "invalid session_token" +@pytest.mark.asyncio @respx.mock -def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: Settings): +async def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: Settings): """ Demonstrate that concurrent requests only trigger one JWKS refresh. """ @@ -590,18 +598,18 @@ def test_fetch_rsa_keys_only_refreshes_once_when_cache_is_expired(mock_settings: jwks_url = f"https://{mock_settings.auth0_domain}/.well-known/jwks.json" jwks_response = {"keys": [generate_dummy_rsa_key("test-key")]} - start_gate = threading.Barrier(5) - first_request_started = threading.Event() - release_refresh = threading.Event() + start_gate = asyncio.Event() + first_request_started = asyncio.Event() + release_refresh = asyncio.Event() call_count = 0 - call_count_lock = threading.Lock() + call_count_lock = asyncio.Lock() - def slow_jwks_response(*args, **kwargs): + async def slow_jwks_response(*args, **kwargs): nonlocal call_count - with call_count_lock: + async with call_count_lock: call_count += 1 first_request_started.set() - release_refresh.wait() + await asyncio.wait_for(release_refresh.wait(), timeout=10) return Response(200, json=jwks_response) metadata_route = respx.get(metadata_url).mock( @@ -609,24 +617,19 @@ def slow_jwks_response(*args, **kwargs): ) jwks_route = respx.get(jwks_url).mock(side_effect=slow_jwks_response) - results = [] - results_lock = threading.Lock() + results: list[dict] = [] - def worker(): - start_gate.wait(timeout=10) - result = _fetch_rsa_keys(mock_settings.auth0_domain) - with results_lock: - results.append(result) + async def worker(): + await asyncio.wait_for(start_gate.wait(), timeout=10) + result = await _fetch_rsa_keys(mock_settings.auth0_domain) + results.append(result) - threads = [threading.Thread(target=worker) for _ in range(5)] - for thread in threads: - thread.start() + tasks = [asyncio.create_task(worker()) for _ in range(5)] + start_gate.set() - first_request_started.wait() + await asyncio.wait_for(first_request_started.wait(), timeout=10) release_refresh.set() - - for thread in threads: - thread.join(timeout=10) + await asyncio.gather(*tasks) # Check all calls went through assert len(results) == 5 diff --git a/tests/db/test_db_admin.py b/tests/db/test_db_admin.py index f1cefeff..1a9b5d87 100644 --- a/tests/db/test_db_admin.py +++ b/tests/db/test_db_admin.py @@ -111,7 +111,7 @@ async def test_auth_callback_invalid_jwt_raises(mocker, mock_settings, mock_requ oauth = Mock() oauth.create_client.return_value = AsyncMock(authorize_access_token=AsyncMock(return_value={"access_token": "token"})) mocker.patch("db.st_admin.setup_oauth", return_value=oauth) - mocker.patch("db.st_admin.verify_jwt", return_value=None) + mocker.patch("db.st_admin.verify_jwt", new=AsyncMock(return_value=None)) from db.st_admin import Auth0AuthProvider provider = Auth0AuthProvider() @@ -134,7 +134,7 @@ async def test_auth_callback_missing_admin_role_raises(mocker, mock_settings, mo payload = Mock() payload.has_admin_role.return_value = False - mocker.patch("db.st_admin.verify_jwt", return_value=payload) + mocker.patch("db.st_admin.verify_jwt", new=AsyncMock(return_value=payload)) from db.st_admin import Auth0AuthProvider provider = Auth0AuthProvider() @@ -160,7 +160,7 @@ async def test_auth_callback_success_sets_session_and_redirects(mocker, mock_set payload = Mock() payload.has_admin_role.return_value = True - mocker.patch("db.st_admin.verify_jwt", return_value=payload) + mocker.patch("db.st_admin.verify_jwt", new=AsyncMock(return_value=payload)) # emulate ?next=/db-admin/ mock_request.query_params = {"next": "/db-admin/"} diff --git a/tests/test_user.py b/tests/test_user.py index 3a0aa313..a53565e3 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -50,7 +50,8 @@ def mock_auth_token(mocker): sub="auth0|123456789", biocommons_roles=["acdc/indexd_admin"], ) - mocker.patch("auth.validator.verify_jwt", return_value=token) + mocker.patch("auth.validator.verify_jwt", new=AsyncMock(return_value=token)) + mocker.patch("auth.user_permissions.verify_jwt", new=AsyncMock(return_value=token)) mocker.patch("auth.management.get_management_token", return_value="mock_token") return token @@ -176,8 +177,8 @@ def test_check_is_admin_with_admin_role(test_client, mock_settings, mocker, test ) admin_user = SessionUserFactory.build(access_token=admin_token) - mocker.patch("auth.user_permissions.verify_jwt", return_value=admin_token) - mocker.patch("auth.user_permissions.get_session_user", return_value=admin_user) + mocker.patch("auth.user_permissions.verify_jwt", new=AsyncMock(return_value=admin_token)) + mocker.patch("auth.user_permissions.get_session_user", new=AsyncMock(return_value=admin_user)) response = test_client.get( "/me/is-general-admin", @@ -198,8 +199,8 @@ def test_check_is_admin_with_non_admin_role(test_client, mock_settings, mocker, ) user = SessionUserFactory.build(access_token=user_token) - mocker.patch("auth.user_permissions.verify_jwt", return_value=user_token) - mocker.patch("auth.user_permissions.get_session_user", return_value=user) + mocker.patch("auth.user_permissions.verify_jwt", new=AsyncMock(return_value=user_token)) + mocker.patch("auth.user_permissions.get_session_user", new=AsyncMock(return_value=user)) response = test_client.get( "/me/is-general-admin", @@ -223,8 +224,8 @@ def _act_as_user(mocker, db_user, roles: list[str] = None): """ access_token = AccessTokenPayloadFactory.build(sub=db_user.id, biocommons_roles=roles or []) auth0_user = SessionUserFactory.build(access_token=access_token) - mocker.patch("auth.user_permissions.verify_jwt", return_value=access_token) - mocker.patch("routers.user.get_session_user", return_value=auth0_user) + mocker.patch("auth.user_permissions.verify_jwt", new=AsyncMock(return_value=access_token)) + mocker.patch("routers.user.get_session_user", new=AsyncMock(return_value=auth0_user)) return auth0_user