Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Canonical reference for changes, improvements, and bugfixes for cap.

## Next

* feat (jwt): add optional callback to dynamically fetch key sets in validator ([PR #187](https://github.com/hashicorp/cap/pull/187))
* fix (saml): always validate response and assertion signatures when signatures are present ([PR #180](https://github.com/hashicorp/cap/pull/180))

## 0.12.0
Expand Down
75 changes: 68 additions & 7 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,24 @@ import (
// for validating the "nbf" (Not Before) and "exp" (Expiration Time) claims.
const DefaultLeewaySeconds = 150

// KeySetSearcher is an optional callback function that can be used to
// implement dynamic KeySet lookup based on the key ID (kid) from the JWT header.
// If provided to the Validator, it will be called to locate the appropriate KeySet
// for signature verification instead of iterating through a static list of KeySets.
type KeySetSearcher func(ctx context.Context, keyID string) (KeySet, error)

// Validator validates JSON Web Tokens (JWT) by providing signature
// verification and claims set validation. Validator can contain either
// a single or multiple KeySets and will attempt to verify the JWT by iterating
// through the configured KeySets.
// verification and claims set validation. Validator can be configured with either:
// - One or more KeySets: The validator will attempt to verify the JWT by iterating
// through the configured KeySets until one succeeds.
// - A KeySetSearcher callback: The validator will extract the key ID from the JWT header
// and use the callback to locate the appropriate KeySet for signature verification.
//
// Use NewValidator to create a Validator with KeySets, or NewValidatorWithKeySetSearcher
// to create a Validator with a KeySet searcher callback.
type Validator struct {
keySets []KeySet
keySets []KeySet
keySearcher KeySetSearcher
}

// NewValidator returns a Validator that uses the given KeySet to verify JWT signatures.
Expand All @@ -43,6 +55,18 @@ func NewValidator(keySets ...KeySet) (*Validator, error) {
}, nil
}

// NewValidatorWithKeySetSearcher returns a Validator that uses a KeySetSearcher
// callback to dynamically locate KeySets based on the key ID from the JWT header.
func NewValidatorWithKeySetSearcher(keySetSearcher KeySetSearcher) (*Validator, error) {
if keySetSearcher == nil {
return nil, errors.New("keySetSearcher must not be nil")
}

return &Validator{
keySearcher: keySetSearcher,
}, nil
}

// Expected defines the expected claims values to assert when validating a JWT.
// For claims that involve validation of the JWT with respect to time, leeway
// fields are provided to account for potential clock skew.
Expand Down Expand Up @@ -129,12 +153,49 @@ func (v *Validator) validateAll(ctx context.Context, token string, expected Expe

// Ensure that the token is signed by at least one of the given key sets
var tokenVerified bool
for _, keySet := range v.keySets {
// First, verify the signature to ensure subsequent validation is against verified claims

if v.keySearcher != nil {
// Use the KeySetSearcher callback to dynamically locate the appropriate KeySet based on the JWT's kid header
var jws *jose.JSONWebSignature
jws, err = jose.ParseSigned(token)
if err != nil {
return nil, fmt.Errorf("error parsing token: %w", err)
}
if len(jws.Signatures) == 0 {
return nil, fmt.Errorf("token must be signed")
}
if len(jws.Signatures) > 1 {
return nil, fmt.Errorf("token with multiple signatures not supported")
}

// Extract the kid (key ID) from the JWS header to locate the appropriate KeySet
keyID := jws.Signatures[0].Header.KeyID
if keyID == "" {
return nil, fmt.Errorf("token missing kid header parameter")
}

var keySet KeySet
keySet, err = v.keySearcher(ctx, keyID)
if err != nil {
return nil, fmt.Errorf("error searching for key set with kid %s: %w", keyID, err)
}
if keySet == nil {
return nil, fmt.Errorf("no key set found with kid %s", keyID)
}

allClaims, err = keySet.VerifySignature(ctx, token)
if err == nil {
tokenVerified = true
break
}
} else {
// Ensure that the token is signed by at least one of the given key sets
for _, keySet := range v.keySets {
// First, verify the signature to ensure subsequent validation is against verified claims
allClaims, err = keySet.VerifySignature(ctx, token)
if err == nil {
tokenVerified = true
break
}
}
}

Expand Down
209 changes: 209 additions & 0 deletions jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
"time"

"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -641,6 +644,212 @@ func TestNewValidator(t *testing.T) {
}
}

// TestNewValidatorWithKeySetSearcher tests cases for creating a new Validator with a KeySetSearcher.
func TestNewValidatorWithKeySetSearcher(t *testing.T) {
type args struct {
searcher KeySetSearcher
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid key searcher",
args: args{
searcher: func(ctx context.Context, keyID string) (KeySet, error) {
return nil, nil // Just for constructor validation
},
},
},
{
name: "nil key searcher",
args: args{
searcher: nil,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewValidatorWithKeySetSearcher(tt.args.searcher)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, got)
})
}
}

// TestValidator_WithKeySetSearcher tests the KeySetSearcher code path.
func TestValidator_WithKeySetSearcher(t *testing.T) {
tp := oidc.StartTestProvider(t)
tp.SetSigningKeys(priv, priv.Public(), oidc.RS256, testKeyID)

// Create the KeySet to be used to verify JWT signatures
keySet, err := NewJSONWebKeySet(context.Background(), tp.Addr()+wellKnownJWKS, tp.CACert())
require.NoError(t, err)

now := time.Now()
nowUnix := float64(now.Unix())
futureUnix := float64(now.Add(2 * jwt.DefaultLeeway).Unix())

t.Run("key searcher is called and signature verification succeeds", func(t *testing.T) {
claims := map[string]interface{}{
"iss": "https://example.com/",
"iat": nowUnix,
"exp": futureUnix,
}
token := oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID))

// Track whether key searcher is called
called := false
keySearcher := func(ctx context.Context, keyID string) (KeySet, error) {
called = true
require.Equal(t, testKeyID, keyID, "searcher should be called with kid from JWT header")
return keySet, nil
}

validator, err := NewValidatorWithKeySetSearcher(keySearcher)
require.NoError(t, err)

got, err := validator.Validate(context.Background(), token, Expected{
Issuer: "https://example.com/",
})

require.NoError(t, err)
require.NotNil(t, got)
require.True(t, called, "key searcher should have been called")
require.Equal(t, "https://example.com/", got["iss"])
})

t.Run("error from key searcher is propagated", func(t *testing.T) {
keySearcher := func(ctx context.Context, keyID string) (KeySet, error) {
return nil, errors.New("key set not found")
}

validator, err := NewValidatorWithKeySetSearcher(keySearcher)
require.NoError(t, err)

claims := map[string]interface{}{
"iss": "https://example.com/",
"iat": nowUnix,
"exp": futureUnix,
}
token := oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID))

_, err = validator.Validate(context.Background(), token, Expected{})

require.Error(t, err)
require.Contains(t, err.Error(), "key set not found")
})

t.Run("error when searcher returns nil KeySet", func(t *testing.T) {
keySearcher := func(ctx context.Context, keyID string) (KeySet, error) {
return nil, nil // Returns nil without error
}

validator, err := NewValidatorWithKeySetSearcher(keySearcher)
require.NoError(t, err)

claims := map[string]interface{}{
"iss": "https://example.com/",
"iat": nowUnix,
"exp": futureUnix,
}
token := oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID))

_, err = validator.Validate(context.Background(), token, Expected{})

require.Error(t, err)
require.Contains(t, err.Error(), "no key set found")
})

t.Run("error when JWT is missing kid header", func(t *testing.T) {
keySearcher := func(ctx context.Context, keyID string) (KeySet, error) {
return keySet, nil
}

validator, err := NewValidatorWithKeySetSearcher(keySearcher)
require.NoError(t, err)

claims := map[string]interface{}{
"iss": "https://example.com/",
"iat": nowUnix,
"exp": futureUnix,
}
// Create JWT without kid header by passing nil as keyID
token := oidc.TestSignJWT(t, priv, string(RS256), claims, nil)

_, err = validator.Validate(context.Background(), token, Expected{})

require.Error(t, err)
require.Contains(t, err.Error(), "token missing kid header parameter")
})

t.Run("error when JWT is malformed", func(t *testing.T) {
keySearcher := func(ctx context.Context, keyID string) (KeySet, error) {
return keySet, nil
}

validator, err := NewValidatorWithKeySetSearcher(keySearcher)
require.NoError(t, err)

// Use a malformed JWT token
malformedToken := "not.a.valid.jwt.token"

_, err = validator.Validate(context.Background(), malformedToken, Expected{})

require.Error(t, err)
require.Contains(t, err.Error(), "error parsing token")
})

t.Run("error when JWT has multiple signatures", func(t *testing.T) {
keySearcher := func(ctx context.Context, keyID string) (KeySet, error) {
return keySet, nil
}

validator, err := NewValidatorWithKeySetSearcher(keySearcher)
require.NoError(t, err)

// Create a valid JWT and then modify it to have multiple signatures
claims := map[string]interface{}{
"iss": "https://example.com/",
"iat": nowUnix,
"exp": futureUnix,
}
token := oidc.TestSignJWT(t, priv, string(RS256), claims, []byte(testKeyID))

// Parse the token and duplicate its signature
parsedJWS, err := jose.ParseSigned(token)
require.NoError(t, err)

// Manually create a JSON with duplicate signatures
var jwsMap map[string]interface{}
err = json.Unmarshal([]byte(parsedJWS.FullSerialize()), &jwsMap)
require.NoError(t, err)

sig := jwsMap["signature"].(string)
protected := jwsMap["protected"].(string)

// Create JSON with two identical signatures
multiSigJSON := fmt.Sprintf(`{
"payload": %q,
"signatures": [
{"protected": %q, "signature": %q},
{"protected": %q, "signature": %q}
]
}`, jwsMap["payload"], protected, sig, protected, sig)

_, err = validator.Validate(context.Background(), multiSigJSON, Expected{})

require.Error(t, err)
require.Contains(t, err.Error(), "token with multiple signatures not supported")
})
}

// TestValidator_MultipleKeySets_Validate_Valid_JWT tests cases where a JWT is expected to be valid where the
// validator is initialized with multiple KeySets.
func TestValidator_MultipleKeySets_Validate_Valid_JWT(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion oidc/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestSignJWT(t TestingT, key crypto.PrivateKey, alg string, claims interface

hdr := map[jose.HeaderKey]interface{}{}
if keyID != nil {
hdr["key_id"] = string(keyID)
hdr["kid"] = string(keyID)
}

sig, err := jose.NewSigner(
Expand Down
Loading