# Mirror of https://github.com/nesquena/hermes-webui.git
# Synced 2026-07-04 14:41:05 +00:00 (696 lines, 23 KiB, Python).
import copy
|
|
import base64
|
|
import hashlib
|
|
import ipaddress
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import secrets
|
|
import socket
|
|
import threading
|
|
import time
|
|
import urllib.error
|
|
import urllib.parse
|
|
import urllib.request
|
|
from typing import Any
|
|
|
|
from cryptography.exceptions import InvalidSignature
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa, utils
|
|
|
|
from api.config import get_config
|
|
|
|
logger = logging.getLogger(__name__)

# Scopes requested when the operator configures none.
_DEFAULT_SCOPES: tuple[str, ...] = ("openid", "profile", "email")

# Seconds a started login (state/nonce/PKCE verifier) stays redeemable.
_PENDING_TTL_SECONDS: int = 600

# Cap on simultaneously pending login flows; the oldest are evicted first.
_MAX_PENDING_FLOWS: int = 128

# Leeway in seconds applied to the id_token exp/nbf/iat checks.
_CLOCK_SKEW_SECONDS: int = 60

# Seconds a fetched discovery or JWKS document is served from cache.
_CACHE_TTL_SECONDS: int = 300

# state -> pending flow payload; only touched while holding _pending_lock.
_pending_lock = threading.Lock()
_pending_flows: dict[str, dict[str, Any]] = {}

# discovery URL -> (expires_at, document); guarded by _discovery_lock.
_discovery_lock = threading.Lock()
_discovery_cache: dict[str, tuple[float, dict[str, Any]]] = {}

# jwks_uri -> (expires_at, document); guarded by _jwks_lock.
_jwks_lock = threading.Lock()
_jwks_cache: dict[str, tuple[float, dict[str, Any]]] = {}
|
|
|
|
|
|
class _NoRedirect(urllib.request.HTTPRedirectHandler):
    """Redirect handler that declines every HTTP redirect.

    Returning None from ``redirect_request`` makes urllib report the 3xx
    response as an ``HTTPError`` rather than issuing a follow-up request, so
    requests stay on the URL that was checked by the caller.
    """

    def redirect_request(self, *args, **kwargs):
        # The redirect details are irrelevant: every redirect is refused.
        del args, kwargs
        return None
|
|
|
|
|
|
class OIDCConfigError(Exception):
    """Raised when native OIDC login is not configured or cannot be set up.

    Covers missing issuer/client_id/allowlist settings and a provider
    discovery document that lacks a required endpoint (authorization,
    token or JWKS).
    """
|
|
|
|
|
|
class OIDCAuthError(Exception):
    """Raised when an OIDC login attempt fails.

    Attributes:
        status_code: HTTP status to report for the failure. Defaults to 401;
            this module passes 403 for allowlist rejections and 502 for
            problems with the upstream provider.
    """

    def __init__(self, message: str, *, status_code: int = 401):
        super().__init__(message)
        self.status_code: int = status_code
|
|
|
|
|
|
def is_oidc_enabled() -> bool:
    """Return True when every setting native OIDC login needs is present.

    The same four settings are required by ``_require_oidc_config``:
    issuer, client_id, allow_claim and allow_values must all be non-empty.
    """
    settings = _resolve_oidc_config()
    required = ("issuer", "client_id", "allow_claim", "allow_values")
    return all(settings.get(name) for name in required)
|
|
|
|
|
|
def build_authorization_redirect(
    request_base_url: str,
    next_path: str | None = None,
) -> str:
    """Start an authorization-code + PKCE login and return the provider redirect URL.

    A fresh ``state``, ``nonce`` and PKCE ``code_verifier`` are generated and
    stored as a pending flow keyed by ``state``. That flow is later consumed by
    ``complete_authorization_code_flow``.

    Args:
        request_base_url: Base URL of the incoming request. Used to derive the
            callback URL when no explicit ``redirect_uri`` is configured.
        next_path: Local path to return to after login. It is sanitized with
            ``_safe_next_path``.

    Returns:
        The provider's authorization endpoint with the request parameters
        added to its query string. Any query the endpoint already carries is
        preserved.

    Raises:
        OIDCConfigError: OIDC is not configured, or the discovery document
            has no authorization_endpoint.
        OIDCAuthError: the discovery document could not be fetched or parsed.
    """
    cfg = _require_oidc_config()
    discovery = _get_discovery_document(cfg["issuer"])
    authorization_endpoint = str(discovery.get("authorization_endpoint") or "").strip()
    if not authorization_endpoint:
        raise OIDCConfigError("OIDC discovery document is missing authorization_endpoint")

    redirect_uri = _resolve_redirect_uri(cfg, request_base_url)
    state = secrets.token_urlsafe(24)
    nonce = secrets.token_urlsafe(24)
    verifier = secrets.token_urlsafe(48)
    # PKCE S256: challenge = BASE64URL(SHA256(ASCII(code_verifier))), unpadded.
    challenge = _b64u(hashlib.sha256(verifier.encode("ascii")).digest())

    _store_pending_flow(
        state,
        {
            "created_at": time.time(),
            "nonce": nonce,
            "code_verifier": verifier,
            "next_path": _safe_next_path(next_path),
        },
    )

    params = {
        "response_type": "code",
        "client_id": cfg["client_id"],
        "redirect_uri": redirect_uri,
        "scope": " ".join(cfg["scopes"]),
        "state": state,
        "nonce": nonce,
        "code_challenge": challenge,
        "code_challenge_method": "S256",
    }

    # RFC 6749 §3.1 lets the endpoint carry its own query component (e.g. a
    # policy selector), which must be retained. Append with "&" in that case
    # instead of emitting a second "?".
    if "?" not in authorization_endpoint:
        separator = "?"
    elif authorization_endpoint.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"
    return authorization_endpoint + separator + urllib.parse.urlencode(params)
|
|
|
|
|
|
def complete_authorization_code_flow(
    request_base_url: str,
    state: str,
    code: str,
) -> dict[str, Any]:
    """Finish a login that ``build_authorization_redirect`` started.

    The pending flow for ``state`` is consumed, so each state can be used
    once. ``code`` is exchanged at the token endpoint together with the stored
    PKCE verifier. The returned id_token is then validated and the configured
    allowlist is enforced.

    Args:
        request_base_url: Base URL of the callback request. The redirect_uri
            sent to the token endpoint is derived from it, in the same way as
            in ``build_authorization_redirect``.
        state: The ``state`` value echoed back by the provider.
        code: The authorization code returned by the provider.

    Returns:
        A dict with:
            ``next_path``: the sanitized post-login path.
            ``subject``: the ``sub`` claim.
            ``email``: the ``email`` claim, or "".
            ``claims``: all validated claims.

    Raises:
        OIDCConfigError: OIDC is not configured, or discovery lacks the token
            endpoint or jwks_uri.
        OIDCAuthError: unknown or expired state, provider/upstream problems,
            an invalid id_token, or an allowlist rejection. ``status_code``
            tells which.
    """
    cfg = _require_oidc_config()

    flow = _consume_pending_flow(state)
    if flow is None:
        raise OIDCAuthError("Invalid OIDC state", status_code=401)

    discovery = _get_discovery_document(cfg["issuer"])
    advertised_issuer = str(discovery.get("issuer") or "").strip()
    # An absent issuer is tolerated here; the id_token's "iss" is still checked.
    if advertised_issuer and advertised_issuer != cfg["issuer"]:
        raise OIDCAuthError("OIDC discovery issuer did not match the configured issuer", status_code=502)

    token_endpoint = str(discovery.get("token_endpoint") or "").strip()
    if not token_endpoint:
        raise OIDCConfigError("OIDC discovery document is missing token_endpoint")

    redirect_uri = _resolve_redirect_uri(cfg, request_base_url)
    token_request = {
        "grant_type": "authorization_code",
        "client_id": cfg["client_id"],
        "code": code,
        "code_verifier": flow["code_verifier"],
        "redirect_uri": redirect_uri,
    }
    # Public clients have no secret; send it only when one is configured.
    if cfg.get("client_secret"):
        token_request["client_secret"] = cfg["client_secret"]
    token_response = _post_form_json(token_endpoint, token_request)

    id_token = str(token_response.get("id_token") or "").strip()
    if not id_token:
        raise OIDCAuthError("OIDC token response did not include an id_token", status_code=502)

    jwks_uri = str(discovery.get("jwks_uri") or "").strip()
    claims = _validate_id_token(
        id_token,
        client_id=cfg["client_id"],
        issuer=cfg["issuer"],
        nonce=flow["nonce"],
        jwks_uri=jwks_uri,
    )
    _enforce_allowlist(
        claims,
        allow_claim=cfg.get("allow_claim"),
        allow_values=cfg.get("allow_values") or [],
    )

    subject = str(claims.get("sub") or "")
    email = str(claims.get("email") or "")
    return {
        "next_path": flow["next_path"],
        "subject": subject,
        "email": email,
        "claims": claims,
    }
|
|
|
|
|
|
def _resolve_oidc_config() -> dict[str, Any]:
    """Merge the ``webui_oidc`` config section with HERMES_WEBUI_OIDC_* env vars.

    An environment variable that is set (even to an empty string) takes
    precedence over the config file. Scalar settings are stringified and
    stripped. ``scopes`` and ``allow_values`` are normalized into lists.
    """
    file_settings: dict[str, Any] = {}
    try:
        loaded = get_config()
        section = loaded.get("webui_oidc") if isinstance(loaded, dict) else None
        if isinstance(section, dict):
            file_settings.update(section)
    except Exception:
        # A config read failure is logged at debug level and otherwise
        # ignored; env vars can still supply every setting.
        logger.debug("Failed to read webui_oidc config", exc_info=True)

    env_names = {
        "issuer": "HERMES_WEBUI_OIDC_ISSUER",
        "client_id": "HERMES_WEBUI_OIDC_CLIENT_ID",
        "client_secret": "HERMES_WEBUI_OIDC_CLIENT_SECRET",
        "redirect_uri": "HERMES_WEBUI_OIDC_REDIRECT_URI",
        "scopes": "HERMES_WEBUI_OIDC_SCOPES",
        "allow_claim": "HERMES_WEBUI_OIDC_ALLOW_CLAIM",
        "allow_values": "HERMES_WEBUI_OIDC_ALLOW_VALUES",
    }

    def setting(name: str) -> Any:
        override = os.getenv(env_names[name])
        return file_settings.get(name) if override is None else override

    def text(name: str) -> str:
        return str(setting(name) or "").strip()

    return {
        "issuer": text("issuer"),
        "client_id": text("client_id"),
        "client_secret": text("client_secret"),
        "redirect_uri": text("redirect_uri"),
        "scopes": _normalize_scopes(setting("scopes")),
        "allow_claim": text("allow_claim"),
        "allow_values": _normalize_allow_values(setting("allow_values")),
    }
|
|
|
|
|
|
def _require_oidc_config() -> dict[str, Any]:
    """Return the resolved OIDC settings, or raise if login cannot be offered.

    Raises:
        OIDCConfigError: issuer or client_id is missing, or the allowlist
            (allow_claim plus allow_values) is not configured.
    """
    settings = _resolve_oidc_config()
    has_client = bool(settings.get("issuer")) and bool(settings.get("client_id"))
    if not has_client:
        raise OIDCConfigError("Native OIDC login is not configured")
    has_allowlist = bool(settings.get("allow_claim")) and bool(settings.get("allow_values"))
    if not has_allowlist:
        raise OIDCConfigError("Native OIDC login requires webui_oidc.allow_claim and allow_values")
    return settings
|
|
|
|
|
|
def _normalize_scopes(raw: Any) -> list[str]:
    """Turn configured scopes into a de-duplicated list that contains ``openid``.

    Falls back to ``_DEFAULT_SCOPES`` when nothing is configured. ``openid`` is
    put first if it is missing. Otherwise the first-occurrence order is kept.
    """
    scopes = _normalize_text_list(raw)
    if not scopes:
        return list(_DEFAULT_SCOPES)
    if "openid" not in scopes:
        scopes = ["openid", *scopes]
    # dict preserves insertion order, so this de-duplicates keeping the first copy.
    return list(dict.fromkeys(scopes))
|
|
|
|
|
|
def _normalize_allow_values(raw: Any) -> list[str]:
    """Normalize allowlist values with the shared text-list parser."""
    values = _normalize_text_list(raw)
    return values
|
|
|
|
|
|
def _normalize_text_list(raw: Any) -> list[str]:
    """Coerce a config value into a list of non-empty, stripped strings.

    A list, tuple or set contributes one entry per item. Items are not split
    further. Any other value is stringified and split on commas and on any
    whitespace, including newlines. ``None`` yields an empty list.
    """
    if raw is None:
        return []
    if isinstance(raw, (list, tuple, set)):
        candidates = (str(item).strip() for item in raw)
    else:
        candidates = str(raw).replace(",", " ").split()
    return [entry for entry in candidates if entry]
|
|
|
|
|
|
def _safe_next_path(raw_path: str | None) -> str:
|
|
path = str(raw_path or "").strip()
|
|
if not path:
|
|
return "/"
|
|
if path[0] != "/":
|
|
return "/"
|
|
if path[1:2] in {"/", "\\"}:
|
|
return "/"
|
|
if any(ord(ch) < 32 or ord(ch) == 127 or ch.isspace() for ch in path):
|
|
return "/"
|
|
return path
|
|
|
|
|
|
def _resolve_redirect_uri(cfg: dict[str, Any], request_base_url: str) -> str:
    """Return the configured redirect_uri, or the callback under ``request_base_url``."""
    configured = str(cfg.get("redirect_uri") or "").strip()
    return configured or f"{request_base_url.rstrip('/')}/api/auth/oidc/callback"
|
|
|
|
|
|
def _store_pending_flow(state: str, payload: dict[str, Any]) -> None:
    """Record a pending login flow under ``state``.

    Expired flows are dropped and the oldest are evicted before inserting, so
    at most ``_MAX_PENDING_FLOWS`` flows are held.
    """
    timestamp = time.time()
    with _pending_lock:
        _prune_pending_flows(timestamp)
        _trim_pending_flows()
        _pending_flows[state] = payload
|
|
|
|
|
|
def _consume_pending_flow(state: str) -> dict[str, Any] | None:
    """Remove and return the pending flow for ``state``, making it single-use.

    Returns None when the state is unknown or has expired.
    """
    timestamp = time.time()
    with _pending_lock:
        _prune_pending_flows(timestamp)
        return _pending_flows.pop(state, None)
|
|
|
|
|
|
def _prune_pending_flows(now: float) -> None:
    """Drop pending flows older than ``_PENDING_TTL_SECONDS`` relative to ``now``.

    A flow without ``created_at`` counts as created at time 0. The callers in
    this module (``_store_pending_flow``, ``_consume_pending_flow``) hold
    ``_pending_lock`` while calling this.
    """
    stale_states = [
        flow_state
        for flow_state, flow in _pending_flows.items()
        if now - float(flow.get("created_at") or 0) > _PENDING_TTL_SECONDS
    ]
    for flow_state in stale_states:
        _pending_flows.pop(flow_state, None)
|
|
|
|
|
|
def _trim_pending_flows() -> None:
    """Evict the oldest pending flows so one more fits within ``_MAX_PENDING_FLOWS``.

    Age comes from each flow's ``created_at``; a missing value sorts as oldest.
    ``_store_pending_flow`` calls this while holding ``_pending_lock``.
    """
    excess = len(_pending_flows) + 1 - _MAX_PENDING_FLOWS
    if excess <= 0:
        return

    def created_at(flow_state: str) -> float:
        return float(_pending_flows[flow_state].get("created_at") or 0)

    for flow_state in sorted(_pending_flows, key=created_at)[:excess]:
        _pending_flows.pop(flow_state, None)
|
|
|
|
|
|
def _get_discovery_document(issuer: str) -> dict[str, Any]:
    """Return the issuer's OpenID discovery document, cached for ``_CACHE_TTL_SECONDS``.

    Raises:
        OIDCAuthError (502): the document could not be fetched, or was not a
            JSON object.
    """
    url = _discovery_url_for_issuer(issuer)
    document = _cache_get(_discovery_lock, _discovery_cache, url)
    if document is None:
        document = _fetch_json(url)
        if not isinstance(document, dict):
            raise OIDCAuthError("OIDC discovery response was not a JSON object", status_code=502)
        _cache_put(_discovery_lock, _discovery_cache, url, document)
    return document
|
|
|
|
|
|
def _discovery_url_for_issuer(issuer: str) -> str:
    """Map an issuer to its discovery URL. A full discovery URL is returned unchanged."""
    suffix = "/.well-known/openid-configuration"
    return issuer if issuer.endswith(suffix) else issuer.rstrip("/") + suffix
|
|
|
|
|
|
def _get_jwks_document(jwks_uri: str, *, force_refresh: bool = False) -> dict[str, Any]:
    """Return the provider JWKS, served from cache unless ``force_refresh`` is set.

    With ``force_refresh`` the cached copy is evicted and the JWKS is fetched
    again.

    Raises:
        OIDCConfigError: ``jwks_uri`` is empty.
        OIDCAuthError (502): the JWKS could not be fetched, or was not a JSON
            object.
    """
    if not jwks_uri:
        raise OIDCConfigError("OIDC discovery document is missing jwks_uri")

    if not force_refresh:
        hit = _cache_get(_jwks_lock, _jwks_cache, jwks_uri)
        if hit is not None:
            return hit
    else:
        with _jwks_lock:
            _jwks_cache.pop(jwks_uri, None)

    keyset = _fetch_json(jwks_uri)
    if not isinstance(keyset, dict):
        raise OIDCAuthError("OIDC JWKS response was not a JSON object", status_code=502)
    _cache_put(_jwks_lock, _jwks_cache, jwks_uri, keyset)
    return keyset
|
|
|
|
|
|
def _cache_get(
|
|
lock: threading.Lock,
|
|
cache: dict[str, tuple[float, dict[str, Any]]],
|
|
key: str,
|
|
) -> dict[str, Any] | None:
|
|
now = time.time()
|
|
with lock:
|
|
entry = cache.get(key)
|
|
if entry is None:
|
|
return None
|
|
expires_at, value = entry
|
|
if expires_at <= now:
|
|
cache.pop(key, None)
|
|
return None
|
|
return copy.deepcopy(value)
|
|
|
|
|
|
def _cache_put(
    lock: threading.Lock,
    cache: dict[str, tuple[float, dict[str, Any]]],
    key: str,
    value: dict[str, Any],
) -> None:
    """Store a deep copy of ``value`` under ``key`` for ``_CACHE_TTL_SECONDS``."""
    snapshot = copy.deepcopy(value)
    with lock:
        cache[key] = (time.time() + _CACHE_TTL_SECONDS, snapshot)
|
|
|
|
|
|
def _fetch_json(url: str) -> dict[str, Any]:
    """GET ``url`` and return the body parsed as a JSON object.

    The URL is checked by ``_validate_outbound_oidc_url`` first. Redirects are
    not followed (see ``_oidc_opener``). NaN and Infinity literals are rejected.

    Raises:
        OIDCAuthError (502): the endpoint could not be reached, the
            connection failed or timed out while reading, or the body was not
            valid JSON or not a JSON object.
    """
    import http.client  # local import: needed only for its exception types

    _validate_outbound_oidc_url(url)
    req = urllib.request.Request(url, headers={"Accept": "application/json"})
    try:
        with _oidc_opener().open(req, timeout=10) as resp:
            payload = json.loads(
                resp.read().decode("utf-8"),
                parse_constant=_reject_non_finite_json_constant,
            )
    except (OSError, http.client.HTTPException) as exc:
        # URLError and HTTPError are OSError subclasses. Failures raised while
        # the response is being read are not wrapped by urllib, so they are
        # caught here as well: TimeoutError, ConnectionResetError,
        # IncompleteRead, RemoteDisconnected.
        raise OIDCAuthError(f"Failed to reach OIDC endpoint: {url}", status_code=502) from exc
    except ValueError as exc:
        # json.JSONDecodeError, UnicodeDecodeError and rejected NaN/Infinity.
        raise OIDCAuthError(f"OIDC endpoint returned invalid JSON: {url}", status_code=502) from exc
    if not isinstance(payload, dict):
        # Report a non-object reply rather than returning {}, which callers
        # would cache and then misreport as a missing endpoint.
        raise OIDCAuthError(f"OIDC endpoint did not return a JSON object: {url}", status_code=502)
    return payload
|
|
|
|
|
|
def _post_form_json(url: str, form_data: dict[str, Any]) -> dict[str, Any]:
    """POST ``form_data`` as a urlencoded form and parse the JSON reply.

    Used for the token endpoint. The URL is checked by
    ``_validate_outbound_oidc_url`` and redirects are not followed. A reply
    that is valid JSON but not an object yields ``{}``.

    Raises:
        OIDCAuthError (502): the request failed (HTTP error status,
            connection failure, timeout or truncated response), or the body
            was not valid JSON.
    """
    import http.client  # local import: needed only for its exception types

    _validate_outbound_oidc_url(url)
    body = urllib.parse.urlencode(form_data).encode("utf-8")
    req = urllib.request.Request(
        url,
        data=body,
        headers={
            "Accept": "application/json",
            "Content-Type": "application/x-www-form-urlencoded",
        },
        method="POST",
    )
    try:
        with _oidc_opener().open(req, timeout=10) as resp:
            payload = json.loads(
                resp.read().decode("utf-8"),
                parse_constant=_reject_non_finite_json_constant,
            )
    except (OSError, http.client.HTTPException) as exc:
        # URLError and HTTPError are OSError subclasses. Errors raised while
        # the response is being read (timeouts, resets, IncompleteRead) are
        # not wrapped by urllib, so they are caught here explicitly.
        raise OIDCAuthError("Failed to exchange the OIDC authorization code", status_code=502) from exc
    except ValueError as exc:
        # json.JSONDecodeError, UnicodeDecodeError and rejected NaN/Infinity.
        raise OIDCAuthError("OIDC token endpoint returned invalid JSON", status_code=502) from exc
    return payload if isinstance(payload, dict) else {}
|
|
|
|
|
|
def _oidc_opener() -> urllib.request.OpenerDirector:
    """Build a urllib opener whose redirect handling refuses to follow redirects."""
    opener = urllib.request.build_opener(_NoRedirect)
    return opener
|
|
|
|
|
|
def _validate_outbound_oidc_url(url: str) -> None:
    """Reject URLs the server must not request during the OIDC flow.

    Checks, in order:
      1. the scheme is https;
      2. there are no embedded credentials;
      3. there is a hostname;
      4. the host is not, and does not resolve to, a disallowed address
         (see ``_is_disallowed_oidc_ip``).

    NOTE(review): the DNS answer checked here is not reused for the actual
    connection, which resolves the name again. A record that changes in
    between is not caught.

    Raises:
        OIDCAuthError (502): the URL fails any check.
    """
    parts = urllib.parse.urlparse(url)
    if parts.scheme != "https":
        raise OIDCAuthError("OIDC endpoint URLs must use https", status_code=502)
    if parts.username or parts.password:
        raise OIDCAuthError("OIDC endpoint URLs must not contain credentials", status_code=502)

    host = (parts.hostname or "").strip()
    if not host:
        raise OIDCAuthError("OIDC endpoint URL was missing a hostname", status_code=502)
    if _is_disallowed_oidc_host(host):
        raise OIDCAuthError(
            "OIDC endpoint URLs must not target private or local addresses",
            status_code=502,
        )
|
|
|
|
|
|
def _is_disallowed_oidc_host(hostname: str) -> bool:
    """Return True if ``hostname`` is, or resolves to, a disallowed IP address.

    IP literals are judged directly. Names are resolved, and the host is
    rejected if ANY resolved address is disallowed. A name that fails to
    resolve (``socket.gaierror``) is not treated as disallowed.
    """
    literal = _parse_ip_address(hostname)
    if literal is not None:
        return _is_disallowed_oidc_ip(literal)

    try:
        resolved = socket.getaddrinfo(hostname, 443, type=socket.SOCK_STREAM)
    except socket.gaierror:
        return False

    # Element 4 of each getaddrinfo tuple is the sockaddr; its first item is the IP string.
    candidates = (_parse_ip_address(info[4][0] if info[4] else "") for info in resolved)
    return any(ip is not None and _is_disallowed_oidc_ip(ip) for ip in candidates)
|
|
|
|
|
|
def _parse_ip_address(value: str):
    """Return an ``ipaddress`` address for ``value``, or None if it is not an IP literal."""
    try:
        parsed = ipaddress.ip_address(value)
    except ValueError:
        parsed = None
    return parsed
|
|
|
|
|
|
def _is_disallowed_oidc_ip(address) -> bool:
    """Return True for addresses an outbound OIDC request must not target.

    Loopback, private, link-local, multicast, unspecified and reserved
    addresses are disallowed. An IPv4-mapped IPv6 address (``::ffff:a.b.c.d``)
    is judged by its embedded IPv4 address.
    """
    effective = getattr(address, "ipv4_mapped", None) or address
    return any(
        (
            effective.is_loopback,
            effective.is_private,
            effective.is_link_local,
            effective.is_multicast,
            effective.is_unspecified,
            effective.is_reserved,
        )
    )
|
|
|
|
|
|
def _reject_non_finite_json_constant(value: str):
    """``json.loads`` ``parse_constant`` hook that refuses NaN, Infinity and -Infinity.

    Raises:
        ValueError: always, naming the offending constant.
    """
    message = f"OIDC JSON contained unsupported constant: {value}"
    raise ValueError(message)
|
|
|
|
|
|
def _validate_id_token(
    token: str,
    *,
    client_id: str,
    issuer: str,
    nonce: str,
    jwks_uri: str,
) -> dict[str, Any]:
    """Verify an id_token's signature and claims, and return its claims.

    Steps:
      1. Decode the JWT.
      2. Refuse a missing alg or alg ``none``.
      3. Find the signing key in the provider JWKS, re-fetching the JWKS once
         if the cached copy lacks the key.
      4. Verify the signature.
      5. Check the registered claims and require a non-blank ``sub``.

    Raises:
        OIDCConfigError: ``jwks_uri`` is empty.
        OIDCAuthError: a decoding, key lookup, signature or claim failure.
    """
    header, claims, signing_input, signature = _parse_jwt(token)

    alg = str(header.get("alg") or "").strip()
    if alg in ("", "none"):
        raise OIDCAuthError("OIDC id_token uses an unsupported signing algorithm")

    cached_keys = _get_jwks_document(jwks_uri)
    try:
        public_key = _select_public_key(cached_keys, header)
    except OIDCAuthError as exc:
        # Only a missing signing key triggers one forced refetch, so a key
        # added after the JWKS was cached can still be found. Other key
        # errors propagate unchanged.
        if "did not contain the signing key" not in str(exc):
            raise
        fresh_keys = _get_jwks_document(jwks_uri, force_refresh=True)
        public_key = _select_public_key(fresh_keys, header)

    _verify_jwt_signature(public_key, alg, signing_input, signature)
    _validate_registered_claims(claims, client_id=client_id, issuer=issuer, nonce=nonce)

    subject = str(claims.get("sub") or "").strip()
    if not subject:
        raise OIDCAuthError("OIDC id_token did not include a subject")
    return claims
|
|
|
|
|
|
def _parse_jwt(token: str) -> tuple[dict[str, Any], dict[str, Any], bytes, bytes]:
    """Split a compact JWS into (header, claims, signing input, signature bytes).

    No signature verification happens here. The signing input is the
    original ``header.payload`` text encoded as ASCII.

    Raises:
        OIDCAuthError: the token does not have three segments, a segment
            cannot be decoded, or the header/payload is not a JSON object.
    """
    segments = token.split(".")
    if len(segments) != 3:
        raise OIDCAuthError("OIDC id_token was not a JWT")
    encoded_header, encoded_claims, encoded_signature = segments

    def load_segment(segment: str) -> Any:
        return json.loads(_b64u_decode(segment), parse_constant=_reject_non_finite_json_constant)

    try:
        header = load_segment(encoded_header)
        claims = load_segment(encoded_claims)
        signature = _b64u_decode_bytes(encoded_signature)
    except Exception as exc:
        raise OIDCAuthError("OIDC id_token could not be decoded") from exc

    if not (isinstance(header, dict) and isinstance(claims, dict)):
        raise OIDCAuthError("OIDC id_token payload was malformed")

    signing_input = ".".join((encoded_header, encoded_claims)).encode("ascii")
    return header, claims, signing_input, signature
|
|
|
|
|
|
def _select_public_key(jwks: dict[str, Any], header: dict[str, Any]):
    """Pick the JWKS key that signed a token and return it as a public key object.

    A key qualifies when all of these hold:
      * it is a dict;
      * its ``use`` is absent or ``sig``;
      * its ``kid`` equals the header's, when the header has one;
      * its ``alg`` is absent or equals the header's;
      * its key type fits the algorithm family.

    The first qualifying key is used.

    Raises:
        OIDCAuthError (502): the JWKS has no keys, or no key qualifies.
    """
    keys = jwks.get("keys")
    if not isinstance(keys, list) or not keys:
        raise OIDCAuthError("OIDC JWKS did not contain any signing keys", status_code=502)

    wanted_kid = str(header.get("kid") or "").strip()
    wanted_alg = str(header.get("alg") or "").strip()

    def qualifies(candidate: Any) -> bool:
        if not isinstance(candidate, dict):
            return False
        if candidate.get("use") not in (None, "sig"):
            return False
        if wanted_kid and str(candidate.get("kid") or "").strip() != wanted_kid:
            return False
        if candidate.get("alg") not in (None, wanted_alg):
            return False
        return _jwk_matches_alg_family(candidate, wanted_alg)

    chosen = next((candidate for candidate in keys if qualifies(candidate)), None)
    if chosen is None:
        raise OIDCAuthError("OIDC JWKS did not contain the signing key for this id_token", status_code=502)
    return _jwk_to_public_key(chosen)
|
|
|
|
|
|
def _jwk_matches_alg_family(jwk: dict[str, Any], alg: str) -> bool:
    """Check that a JWK's key type (and EC curve) suits a JWS ``alg``.

    RS* algorithms need an RSA key. ES* algorithms need an EC key on the curve
    that ``_ec_curve_for_alg`` names. Other algorithms are not filtered here.
    """
    key_type = str(jwk.get("kty") or "").strip()
    if alg.startswith("RS"):
        return key_type == "RSA"
    if not alg.startswith("ES"):
        return True
    curve_name = str(jwk.get("crv") or "").strip()
    return key_type == "EC" and curve_name == _ec_curve_for_alg(alg)
|
|
|
|
|
|
def _ec_curve_for_alg(alg: str) -> str:
    """Return the JWK ``crv`` name an ES* algorithm requires, or "" for other algs."""
    if alg == "ES256":
        return "P-256"
    if alg == "ES384":
        return "P-384"
    if alg == "ES512":
        # ES512 uses P-521, not "P-512".
        return "P-521"
    return ""
|
|
|
|
|
|
def _jwk_to_public_key(jwk: dict[str, Any]):
    """Convert an RSA or EC JWK into a ``cryptography`` public key object.

    Raises:
        OIDCAuthError (502): any of the following:
            * the ``kty`` or EC ``crv`` is unsupported;
            * a key parameter is missing;
            * the key material is rejected, either because it is not valid
              base64url or because the numbers do not form a valid key
              (e.g. an EC point that is not on the named curve).
    """
    kty = str(jwk.get("kty") or "").strip()
    try:
        if kty == "RSA":
            n = _int_from_b64u(jwk.get("n"))
            e = _int_from_b64u(jwk.get("e"))
            return rsa.RSAPublicNumbers(e, n).public_key()
        if kty == "EC":
            crv = str(jwk.get("crv") or "").strip()
            curve = {
                "P-256": ec.SECP256R1(),
                "P-384": ec.SECP384R1(),
                "P-521": ec.SECP521R1(),
            }.get(crv)
            if curve is None:
                raise OIDCAuthError(f"Unsupported OIDC EC curve: {crv}", status_code=502)
            x = _int_from_b64u(jwk.get("x"))
            y = _int_from_b64u(jwk.get("y"))
            return ec.EllipticCurvePublicNumbers(x, y, curve).public_key()
    except ValueError as exc:
        # binascii.Error / UnicodeEncodeError from base64url decoding, or the
        # key constructors rejecting the numbers. Report these as an upstream
        # (502) error rather than letting them escape unhandled.
        raise OIDCAuthError("OIDC JWKS contained an invalid key", status_code=502) from exc
    raise OIDCAuthError(f"Unsupported OIDC key type: {kty}", status_code=502)
|
|
|
|
|
|
def _verify_jwt_signature(public_key, alg: str, signed: bytes, signature: bytes) -> None:
    """Verify a JWS signature over ``signed`` with ``public_key`` for ``alg``.

    Supports RS256/384/512 (PKCS#1 v1.5) and ES256/384/512. ES signatures
    arrive in the JOSE fixed-width ``r || s`` form and are converted to DER
    before verification.

    Raises:
        OIDCAuthError (401): the signature does not verify, or an ES
            signature has the wrong length.
        OIDCAuthError (502): ``alg`` is not supported.
    """
    rsa_hashes = {
        "RS256": hashes.SHA256,
        "RS384": hashes.SHA384,
        "RS512": hashes.SHA512,
    }
    # alg -> (hash class, byte width of each of r and s)
    ecdsa_params = {
        "ES256": (hashes.SHA256, 32),
        "ES384": (hashes.SHA384, 48),
        "ES512": (hashes.SHA512, 66),
    }
    try:
        if alg in rsa_hashes:
            public_key.verify(signature, signed, padding.PKCS1v15(), rsa_hashes[alg]())
            return
        if alg in ecdsa_params:
            hash_cls, part_size = ecdsa_params[alg]
            der_signature = _jose_ecdsa_signature_to_der(signature, part_size)
            public_key.verify(der_signature, signed, ec.ECDSA(hash_cls()))
            return
    except InvalidSignature as exc:
        raise OIDCAuthError("OIDC id_token signature verification failed") from exc
    raise OIDCAuthError(f"Unsupported OIDC signing algorithm: {alg}", status_code=502)
|
|
|
|
|
|
def _jose_ecdsa_signature_to_der(signature: bytes, part_size: int) -> bytes:
    """Convert a JOSE ECDSA signature (big-endian ``r || s``) to DER encoding.

    Raises:
        OIDCAuthError: ``signature`` is not exactly ``2 * part_size`` bytes.
    """
    if len(signature) != 2 * part_size:
        raise OIDCAuthError("OIDC id_token ECDSA signature was malformed")
    r_bytes, s_bytes = signature[:part_size], signature[part_size:]
    return utils.encode_dss_signature(
        int.from_bytes(r_bytes, "big"),
        int.from_bytes(s_bytes, "big"),
    )
|
|
|
|
|
|
def _validate_registered_claims(
    claims: dict[str, Any],
    *,
    client_id: str,
    issuer: str,
    nonce: str,
) -> None:
    """Check the id_token's iss, aud, azp, exp, nbf, iat and nonce claims.

    Checks run in that order, and the first failure is raised.
      * ``exp`` is required.
      * ``nbf`` and ``iat`` are checked only when present.
      * Time checks allow ``_CLOCK_SKEW_SECONDS`` of leeway.
      * ``azp`` is checked only for tokens with several audiences, and may
        then be absent or equal to ``client_id``.

    Raises:
        OIDCAuthError (401): a claim fails validation or is non-numeric
            where a number is required.
    """
    now = time.time()
    leeway = _CLOCK_SKEW_SECONDS

    if str(claims.get("iss") or "").strip() != issuer:
        raise OIDCAuthError("OIDC id_token issuer did not match the configured issuer")

    raw_aud = claims.get("aud")
    if raw_aud is None:
        audiences: list[str] = []
    elif isinstance(raw_aud, list):
        audiences = [str(entry) for entry in raw_aud]
    else:
        audiences = [str(raw_aud)]
    if client_id not in audiences:
        raise OIDCAuthError("OIDC id_token audience did not include this client")

    authorized_party = str(claims.get("azp") or "").strip()
    if len(audiences) > 1 and authorized_party not in {"", client_id}:
        raise OIDCAuthError("OIDC id_token azp did not match this client")

    expires_at = _coerce_numeric_claim(claims, "exp")
    if expires_at is None or expires_at < now - leeway:
        raise OIDCAuthError("OIDC id_token has expired")

    not_before = _coerce_numeric_claim(claims, "nbf")
    if not_before is not None and not_before > now + leeway:
        raise OIDCAuthError("OIDC id_token is not valid yet")

    issued_at = _coerce_numeric_claim(claims, "iat")
    if issued_at is not None and issued_at > now + leeway:
        raise OIDCAuthError("OIDC id_token has an invalid issued-at time")

    if str(claims.get("nonce") or "").strip() != nonce:
        raise OIDCAuthError("OIDC id_token nonce did not match the login request")
|
|
|
|
|
|
def _coerce_numeric_claim(claims: dict[str, Any], name: str) -> float | None:
|
|
value = claims.get(name)
|
|
if value is None:
|
|
return None
|
|
try:
|
|
number = float(value)
|
|
except (TypeError, ValueError) as exc:
|
|
raise OIDCAuthError(f"OIDC id_token claim {name} was not numeric") from exc
|
|
if not math.isfinite(number):
|
|
raise OIDCAuthError(f"OIDC id_token claim {name} was not numeric")
|
|
return number
|
|
|
|
|
|
def _enforce_allowlist(
    claims: dict[str, Any],
    *,
    allow_claim: str,
    allow_values: list[str],
) -> None:
    """Reject identities whose ``allow_claim`` does not carry an allowed value.

    ``allow_claim`` is a dotted path resolved by ``_get_claim_path``.
      * With ``allow_values``, at least one of them must appear among the
        claim's values.
      * Without ``allow_values``, the claim only has to be present with a
        non-blank value.
      * An empty ``allow_claim`` disables the check entirely.

    Raises:
        OIDCAuthError (403): the identity is not allowed.
    """
    if not allow_claim:
        return

    claim_value = _get_claim_path(claims, allow_claim)
    # A missing claim contributes no values, which fails both modes below.
    granted = set() if claim_value is None else _claim_values(claim_value)

    if allow_values:
        allowed = not granted.isdisjoint(allow_values)
    else:
        allowed = bool(granted)
    if not allowed:
        raise OIDCAuthError("OIDC identity is not allowed", status_code=403)
|
|
|
|
|
|
def _get_claim_path(claims: dict[str, Any], dotted_key: str) -> Any:
    """Look up a claim by dotted path, e.g. ``"realm_access.roles"``.

    Normally each dot descends one level into nested objects. Claim names
    that themselves contain dots, such as namespaced custom claims like
    ``"https://app.example.com/roles"``, are also supported: when descending
    one segment at a time does not resolve, longer runs of segments are tried
    as a single key. The one-level-per-dot reading is always tried first and
    wins when it resolves.

    Returns:
        The claim value, or None if the path does not resolve.
    """
    parts = dotted_key.split(".")

    def resolve(node: Any, start: int) -> tuple[bool, Any]:
        # Returns (found, value). ``found`` separates "path missing" from a
        # claim whose value is JSON null.
        if start == len(parts):
            return True, node
        if not isinstance(node, dict):
            return False, None
        # Shortest key first, so the plain one-level-per-dot interpretation
        # is explored before any multi-segment key.
        for end in range(start + 1, len(parts) + 1):
            key = ".".join(parts[start:end])
            if key in node:
                found, value = resolve(node[key], end)
                if found:
                    return True, value
        return False, None

    found, value = resolve(claims, 0)
    return value if found else None
|
|
|
|
|
|
def _claim_values(value: Any) -> set[str]:
    """Flatten a claim value into a set of strings for allowlist matching.

    * Lists, tuples and sets contribute their items; dicts contribute their
      values. These items are stringified (not stripped), and blank ones are
      skipped.
    * Any other value is stringified and stripped. A falsy or blank scalar
      yields an empty set.
    """
    if isinstance(value, dict):
        items = value.values()
    elif isinstance(value, (list, tuple, set)):
        items = value
    else:
        text = str(value or "").strip()
        return {text} if text else set()
    return {str(item) for item in items if str(item).strip()}
|
|
|
|
|
|
def _b64u(data: bytes) -> str:
    """Encode bytes as unpadded base64url text."""
    encoded = base64.urlsafe_b64encode(data).decode("ascii")
    return encoded.rstrip("=")
|
|
|
|
|
|
def _b64u_decode(data: str) -> str:
    """Decode (possibly unpadded) base64url text and read the bytes as UTF-8."""
    raw = _b64u_decode_bytes(data)
    return raw.decode("utf-8")
|
|
|
|
|
|
def _b64u_decode_bytes(data: str) -> bytes:
    """Decode base64url text, first restoring any stripped ``=`` padding."""
    # Python's % is non-negative, so this gives the 0-3 pad characters needed.
    missing = -len(data) % 4
    return base64.urlsafe_b64decode((data + "=" * missing).encode("ascii"))
|
|
|
|
|
|
def _int_from_b64u(data: Any) -> int:
    """Decode a base64url JWK parameter into an unsigned big-endian integer.

    Raises:
        OIDCAuthError (502): the parameter is missing or empty.
    """
    if data:
        return int.from_bytes(_b64u_decode_bytes(str(data)), byteorder="big")
    raise OIDCAuthError("OIDC JWKS key was missing a required parameter", status_code=502)
|