Skip to content

Commit

Permalink
jkws: refactor RemoteKeySet cache to a map
Browse files Browse the repository at this point in the history
this refactors RemoteKeySet to cache keys using a map so that keys can
be looked up by keyID in constant-time.
  • Loading branch information
mattbonnell committed Mar 21, 2021
1 parent 08563f6 commit 936b014
Showing 1 changed file with 35 additions and 21 deletions.
56 changes: 35 additions & 21 deletions oidc/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ type RemoteKeySet struct {
inflight *inflight

// A set of cached keys.
cachedKeys []jose.JSONWebKey
cachedKeys map[string]jose.JSONWebKey
}

// inflight is used to wait on some in-flight request from multiple goroutines.
type inflight struct {
doneCh chan struct{}

keys []jose.JSONWebKey
keys map[string]jose.JSONWebKey
err error
}

Expand All @@ -70,14 +70,14 @@ func (i *inflight) wait() <-chan struct{} {
// done can only be called by a single goroutine. It records the result of the
// inflight request and signals other goroutines that the result is safe to
// inspect.
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
func (i *inflight) done(keys map[string]jose.JSONWebKey, err error) {
i.keys = keys
i.err = err
close(i.doneCh)
}

// result cannot be called until the wait() channel has returned a value.
func (i *inflight) result() ([]jose.JSONWebKey, error) {
func (i *inflight) result() (map[string]jose.JSONWebKey, error) {
return i.keys, i.err
}

Expand All @@ -102,43 +102,53 @@ func (r *RemoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) (
break
}

keys := r.keysFromCache()
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, err := jws.Verify(&key); err == nil {
return payload, nil
}
}
if payload, ok := r.verifyWithKey(keyID, jws); ok {
return payload, nil
}

// If the kid doesn't match, check for new keys from the remote. This is the
// strategy recommended by the spec.
//
// https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys
keys, err := r.keysFromRemote(ctx)
_, err := r.keysFromRemote(ctx)
if err != nil {
return nil, fmt.Errorf("fetching keys %v", err)
}

for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, ok := r.verifyWithKey(keyID, jws); ok {
return payload, nil
}

return nil, errors.New("failed to verify id token signature")
}

// verifyWithKey attempts to verify the jws using the key with keyID from the cache
// if keyID is the empty string, it tries each key in the cache
func (r *RemoteKeySet) verifyWithKey(keyID string, jws *jose.JSONWebSignature) (payload []byte, ok bool) {
if keyID == "" {
for _, key := range r.keysFromCache() {
if payload, err := jws.Verify(&key); err == nil {
return payload, nil
return payload, true
}
}
} else {
if key, ok := r.keysFromCache()[keyID]; ok {
if payload, err := jws.Verify(&key); err == nil {
return payload, true
}
}
}
return nil, errors.New("failed to verify id token signature")
return nil, false
}

func (r *RemoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
func (r *RemoteKeySet) keysFromCache() (keys map[string]jose.JSONWebKey) {
r.mu.Lock()
defer r.mu.Unlock()
return r.cachedKeys
}

// keysFromRemote syncs the key set from the remote set, records the values in the
// cache, and returns the key set.
func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
func (r *RemoteKeySet) keysFromRemote(ctx context.Context) (map[string]jose.JSONWebKey, error) {
// Need to lock to inspect the inflight request field.
r.mu.Lock()
// If there's not a current inflight request, create one.
Expand Down Expand Up @@ -178,7 +188,7 @@ func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, e
}
}

func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
func (r *RemoteKeySet) updateKeys() (map[string]jose.JSONWebKey, error) {
req, err := http.NewRequest("GET", r.jwksURL, nil)
if err != nil {
return nil, fmt.Errorf("oidc: can't create request: %v", err)
Expand All @@ -204,5 +214,9 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
if err != nil {
return nil, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
}
return keySet.Keys, nil
keys := make(map[string]jose.JSONWebKey)
for _, key := range keySet.Keys {
keys[key.KeyID] = key
}
return keys, nil
}

0 comments on commit 936b014

Please sign in to comment.