From dcdbe6f653c3ff4752c2b3ab324097d967e94cf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 3 Mar 2023 18:44:33 +0200 Subject: [PATCH] impelement unit tests for the token Verifiers --- internal/testutil/token.go | 143 +++++++++++ pkg/client/rp/verifier.go | 10 +- pkg/client/rp/verifier_test.go | 343 ++++++++++++++++++++++++++ pkg/oidc/token.go | 1 + pkg/op/verifier_access_token.go | 2 - pkg/op/verifier_access_token_test.go | 129 ++++++++++ pkg/op/verifier_id_token_hint_test.go | 164 ++++++++++++ 7 files changed, 785 insertions(+), 7 deletions(-) create mode 100644 internal/testutil/token.go create mode 100644 pkg/client/rp/verifier_test.go create mode 100644 pkg/op/verifier_access_token_test.go create mode 100644 pkg/op/verifier_id_token_hint_test.go diff --git a/internal/testutil/token.go b/internal/testutil/token.go new file mode 100644 index 00000000..8e800266 --- /dev/null +++ b/internal/testutil/token.go @@ -0,0 +1,143 @@ +// Package testuril helps setting up required data for testing, +// such as tokens, claims and verifiers. +package testutil + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "errors" + "time" + + "github.com/zitadel/oidc/v2/pkg/oidc" + "gopkg.in/square/go-jose.v2" +) + +const SignatureAlgorithm = jose.PS512 + +// KeySet implements oidc.Keys and +// additionally can create tokens and claims that can +// be validated by this KeySet. +type KeySet struct { + Private *rsa.PrivateKey + Public *rsa.PublicKey + + Signer jose.Signer +} + +func NewKeySet() *KeySet { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: SignatureAlgorithm, Key: privateKey}, nil) + if err != nil { + panic(err) + } + return &KeySet{ + Private: privateKey, + Public: &privateKey.PublicKey, + Signer: signer, + } +} + +func (k *KeySet) signEncodeTokenClaims(claims any) string { + payload, err := json.Marshal(claims) + if err != nil { + panic(err) + } + object, err := k.Signer.Sign(payload) + if err != nil { + panic(err) + } + token, err := object.CompactSerialize() + if err != nil { + panic(err) + } + return token +} + +func claimsMap(claims any) map[string]any { + data, err := json.Marshal(claims) + if err != nil { + panic(err) + } + dst := make(map[string]any) + if err = json.Unmarshal(data, &dst); err != nil { + panic(err) + } + return dst +} + +// NewIDToken creates a new IDTokenClaims with passed data and returns a signed token and claims. +func (k *KeySet) NewIDToken(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration, atHash string) (string, *oidc.IDTokenClaims) { + claims := oidc.NewIDTokenClaims(issuer, subject, audience, expiration, authTime, nonce, acr, amr, clientID, skew) + claims.AccessTokenHash = atHash + token := k.signEncodeTokenClaims(claims) + + // set this so that assertion in tests will work + claims.SignatureAlg = SignatureAlgorithm + claims.Claims = claimsMap(claims) + return token, claims +} + +// NewAcccessToken creates a new AccessTokenClaims with passed data and returns a signed token and claims. +func (k *KeySet) NewAccessToken(issuer, subject string, audience []string, expiration time.Time, jwtid, clientID string, skew time.Duration) (string, *oidc.AccessTokenClaims) { + claims := oidc.NewAccessTokenClaims(issuer, subject, audience, expiration, jwtid, clientID, skew) + token := k.signEncodeTokenClaims(claims) + + // set this so that assertion in tests will work + claims.SignatureAlg = SignatureAlgorithm + claims.Claims = claimsMap(claims) + return token, claims +} + +const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg` + +// These variables always result in a valid token +// for the same test run. +var ( + ValidIssuer = "local.com" + ValidSubject = "tim@local.com" + ValidAudience = []string{"unit", "test"} + ValidAuthTime = time.Now().Add(-time.Minute) // authtime is always 1 minute in the past + ValidExpiration = ValidAuthTime.Add(2 * time.Minute) // token is always 1 more minute available + ValidJWTID = "9876" + ValidNonce = "12345" + ValidACR = "something" + ValidAMR = []string{"foo", "bar"} + ValidClientID = "555666" + ValidSkew = time.Second +) + +// ValidIDToken returns a token and claims that are in the token. +// It uses the Valid* global variables and the token always passes +// verification within the same test run. +func (k *KeySet) ValidIDToken() (string, *oidc.IDTokenClaims) { + return k.NewIDToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidAuthTime, ValidNonce, ValidACR, ValidAMR, ValidClientID, ValidSkew, "") +} + +// ValidAccessToken returns a token and claims that are in the token. +// It uses the Valid* global variables and the token always passes +// verification within the same test run. +func (k *KeySet) ValidAccessToken() (string, *oidc.AccessTokenClaims) { + return k.NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew) +} + +// VerifySignature implments op.KeySet. +func (k *KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { + if ctx.Err() != nil { + return nil, err + } + + return jws.Verify(k.Public) +} + +// ACRVerify is a oidc.ACRVerifier func. +func ACRVerify(acr string) error { + if acr != ValidACR { + return errors.New("invalid acr") + } + return nil +} diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index 69e62496..e5c9f42d 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -21,17 +21,17 @@ type IDTokenVerifier interface { // VerifyTokens implement the Token Response Validation as defined in OIDC specification // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation -func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (claims C, err error) { +func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) { var nilClaims C - idToken, err := VerifyIDToken[C](ctx, idTokenString, v) + claims, err = VerifyIDToken[C](ctx, idToken, v) if err != nil { return nilClaims, err } - if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil { + if err := VerifyAccessToken(accessToken, claims.GetAccessTokenHash(), claims.GetSignatureAlgorithm()); err != nil { return nilClaims, err } - return idToken, nil + return claims, nil } // VerifyIDToken validates the id token according to @@ -114,7 +114,7 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ... issuer: issuer, clientID: clientID, keySet: keySet, - offset: 1 * time.Second, + offset: time.Second, nonce: func(_ context.Context) string { return "" }, diff --git a/pkg/client/rp/verifier_test.go b/pkg/client/rp/verifier_test.go new file mode 100644 index 00000000..41c79eab --- /dev/null +++ b/pkg/client/rp/verifier_test.go @@ -0,0 +1,343 @@ +package rp + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v2/internal/testutil" + "github.com/zitadel/oidc/v2/pkg/oidc" + "gopkg.in/square/go-jose.v2" +) + +func TestVerifyTokens(t *testing.T) { + keySet := tu.NewKeySet() + verifier := &idTokenVerifier{ + issuer: tu.ValidIssuer, + maxAgeIAT: 2 * time.Minute, + offset: time.Second, + supportedSignAlgs: []string{string(jose.PS512)}, + keySet: keySet, + maxAge: 2 * time.Minute, + acr: tu.ACRVerify, + nonce: func(context.Context) string { return tu.ValidNonce }, + clientID: tu.ValidClientID, + } + accessToken, _ := keySet.ValidAccessToken() + atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm) + require.NoError(t, err) + + tests := []struct { + name string + accessToken string + idTokenClaims func() (string, *oidc.IDTokenClaims) + wantErr bool + }{ + { + name: "without access token", + idTokenClaims: keySet.ValidIDToken, + }, + { + name: "with access token", + accessToken: accessToken, + idTokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash, + ) + }, + }, + { + name: "expired id token", + accessToken: accessToken, + idTokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash, + ) + }, + wantErr: true, + }, + { + name: "wronf access token", + accessToken: accessToken, + idTokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "~~~", + ) + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idToken, want := tt.idTokenClaims() + got, err := VerifyTokens[*oidc.IDTokenClaims](context.Background(), tt.accessToken, idToken, verifier) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, got, want) + }) + } +} + +func TestVerifyIDToken(t *testing.T) { + keySet := tu.NewKeySet() + verifier := &idTokenVerifier{ + issuer: tu.ValidIssuer, + maxAgeIAT: 2 * time.Minute, + offset: time.Second, + supportedSignAlgs: []string{string(jose.PS512)}, + keySet: keySet, + maxAge: 2 * time.Minute, + acr: tu.ACRVerify, + nonce: func(context.Context) string { return tu.ValidNonce }, + } + + tests := []struct { + name string + clientID string + tokenClaims func() (string, *oidc.IDTokenClaims) + wantErr bool + }{ + { + name: "success", + clientID: tu.ValidClientID, + tokenClaims: keySet.ValidIDToken, + }, + { + name: "parse err", + clientID: tu.ValidClientID, + tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil }, + wantErr: true, + }, + { + name: "invalid signature", + clientID: tu.ValidClientID, + tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil }, + wantErr: true, + }, + { + name: "empty subject", + clientID: tu.ValidClientID, + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, "", tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + { + name: "wrong issuer", + clientID: tu.ValidClientID, + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + "foo", tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + { + name: "wrong clientID", + clientID: "foo", + tokenClaims: keySet.ValidIDToken, + wantErr: true, + }, + { + name: "expired", + clientID: tu.ValidClientID, + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + { + name: "wrong IAT", + clientID: tu.ValidClientID, + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "", + ) + }, + wantErr: true, + }, + { + name: "wrong acr", + clientID: tu.ValidClientID, + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + "else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + { + name: "expired auth", + clientID: tu.ValidClientID, + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + { + name: "wrong nonce", + clientID: tu.ValidClientID, + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, "foo", + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, want := tt.tokenClaims() + verifier.clientID = tt.clientID + got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, got, want) + }) + } +} + +func TestVerifyAccessToken(t *testing.T) { + keySet := tu.NewKeySet() + token, _ := keySet.ValidAccessToken() + hash, err := oidc.ClaimHash(token, tu.SignatureAlgorithm) + require.NoError(t, err) + + type args struct { + accessToken string + atHash string + sigAlgorithm jose.SignatureAlgorithm + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "empty hash", + }, + { + name: "success", + args: args{ + accessToken: token, + atHash: hash, + sigAlgorithm: tu.SignatureAlgorithm, + }, + }, + { + name: "invalid algorithm", + args: args{ + accessToken: token, + atHash: hash, + sigAlgorithm: "foo", + }, + wantErr: true, + }, + { + name: "mismatch", + args: args{ + accessToken: token, + atHash: "~~", + sigAlgorithm: tu.SignatureAlgorithm, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := VerifyAccessToken(tt.args.accessToken, tt.args.atHash, tt.args.sigAlgorithm) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestNewIDTokenVerifier(t *testing.T) { + keySet := tu.NewKeySet() + type args struct { + issuer string + clientID string + keySet oidc.KeySet + options []VerifierOption + } + tests := []struct { + name string + args args + want IDTokenVerifier + }{ + { + name: "nil nonce", // otherwise assert.Equal will fail on the function + args: args{ + issuer: tu.ValidIssuer, + clientID: tu.ValidClientID, + keySet: keySet, + options: []VerifierOption{ + WithIssuedAtOffset(time.Minute), + //WithIssuedAtMaxAge(time.Hour), + WithNonce(nil), // otherwise assert.Equal will fail on the function + WithACRVerifier(nil), + WithAuthTimeMaxAge(2 * time.Hour), + WithSupportedSigningAlgorithms("ABC", "DEF"), + }, + }, + want: &idTokenVerifier{ + issuer: tu.ValidIssuer, + offset: time.Minute, + //maxAgeIAT: time.Hour, // Maybe BUG? + clientID: tu.ValidClientID, + keySet: keySet, + nonce: nil, + acr: nil, + maxAge: 2 * time.Hour, + supportedSignAlgs: []string{"ABC", "DEF"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewIDTokenVerifier(tt.args.issuer, tt.args.clientID, tt.args.keySet, tt.args.options...) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 61ec7a74..20303465 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -144,6 +144,7 @@ func NewIDTokenClaims(issuer, subject string, audience []string, expiration, aut AuthenticationContextClassReference: acr, AuthenticationMethodsReferences: amr, AuthorizedParty: clientID, + ClientID: clientID, }, } } diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go index 76e0abaa..e5f82992 100644 --- a/pkg/op/verifier_access_token.go +++ b/pkg/op/verifier_access_token.go @@ -18,8 +18,6 @@ type accessTokenVerifier struct { maxAgeIAT time.Duration offset time.Duration supportedSignAlgs []string - maxAge time.Duration - acr oidc.ACRVerifier keySet oidc.KeySet } diff --git a/pkg/op/verifier_access_token_test.go b/pkg/op/verifier_access_token_test.go new file mode 100644 index 00000000..718de1c1 --- /dev/null +++ b/pkg/op/verifier_access_token_test.go @@ -0,0 +1,129 @@ +package op + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v2/internal/testutil" + "github.com/zitadel/oidc/v2/pkg/oidc" + "gopkg.in/square/go-jose.v2" +) + +func TestNewAccessTokenVerifier(t *testing.T) { + keySet := tu.NewKeySet() + type args struct { + issuer string + keySet oidc.KeySet + opts []AccessTokenVerifierOpt + } + tests := []struct { + name string + args args + want AccessTokenVerifier + }{ + { + name: "simple", + args: args{ + issuer: tu.ValidIssuer, + keySet: keySet, + }, + want: &accessTokenVerifier{ + issuer: tu.ValidIssuer, + keySet: keySet, + }, + }, + { + name: "with signature algorithm", + args: args{ + issuer: tu.ValidIssuer, + keySet: keySet, + opts: []AccessTokenVerifierOpt{ + WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"), + }, + }, + want: &accessTokenVerifier{ + issuer: tu.ValidIssuer, + keySet: keySet, + supportedSignAlgs: []string{"ABC", "DEF"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewAccessTokenVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestVerifyAccessToken(t *testing.T) { + keySet := tu.NewKeySet() + verifier := &accessTokenVerifier{ + issuer: tu.ValidIssuer, + maxAgeIAT: 2 * time.Minute, + offset: time.Second, + supportedSignAlgs: []string{string(jose.PS512)}, + keySet: keySet, + } + + tests := []struct { + name string + tokenClaims func() (string, *oidc.AccessTokenClaims) + wantErr bool + }{ + { + name: "success", + tokenClaims: keySet.ValidAccessToken, + }, + { + name: "parse err", + tokenClaims: func() (string, *oidc.AccessTokenClaims) { return "~~~~", nil }, + wantErr: true, + }, + { + name: "invalid signature", + tokenClaims: func() (string, *oidc.AccessTokenClaims) { return tu.InvalidSignatureToken, nil }, + wantErr: true, + }, + { + name: "wrong issuer", + tokenClaims: func() (string, *oidc.AccessTokenClaims) { + return keySet.NewAccessToken( + "foo", tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidJWTID, tu.ValidClientID, + tu.ValidSkew, + ) + }, + wantErr: true, + }, + { + name: "expired", + tokenClaims: func() (string, *oidc.AccessTokenClaims) { + return keySet.NewAccessToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration.Add(-time.Hour), tu.ValidJWTID, tu.ValidClientID, + tu.ValidSkew, + ) + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, want := tt.tokenClaims() + + got, err := VerifyAccessToken[*oidc.AccessTokenClaims](context.Background(), token, verifier) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, got, want) + }) + } +} diff --git a/pkg/op/verifier_id_token_hint_test.go b/pkg/op/verifier_id_token_hint_test.go new file mode 100644 index 00000000..27fc0b93 --- /dev/null +++ b/pkg/op/verifier_id_token_hint_test.go @@ -0,0 +1,164 @@ +package op + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v2/internal/testutil" + "github.com/zitadel/oidc/v2/pkg/oidc" + "gopkg.in/square/go-jose.v2" +) + +func TestNewIDTokenHintVerifier(t *testing.T) { + keySet := tu.NewKeySet() + type args struct { + issuer string + keySet oidc.KeySet + opts []IDTokenHintVerifierOpt + } + tests := []struct { + name string + args args + want IDTokenHintVerifier + }{ + { + name: "simple", + args: args{ + issuer: tu.ValidIssuer, + keySet: keySet, + }, + want: &idTokenHintVerifier{ + issuer: tu.ValidIssuer, + keySet: keySet, + }, + }, + { + name: "with signature algorithm", + args: args{ + issuer: tu.ValidIssuer, + keySet: keySet, + opts: []IDTokenHintVerifierOpt{ + WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"), + }, + }, + want: &idTokenHintVerifier{ + issuer: tu.ValidIssuer, + keySet: keySet, + supportedSignAlgs: []string{"ABC", "DEF"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewIDTokenHintVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestVerifyIDTokenHint(t *testing.T) { + keySet := tu.NewKeySet() + verifier := &idTokenHintVerifier{ + issuer: tu.ValidIssuer, + maxAgeIAT: 2 * time.Minute, + offset: time.Second, + supportedSignAlgs: []string{string(jose.PS512)}, + maxAge: 2 * time.Minute, + acr: tu.ACRVerify, + keySet: keySet, + } + + tests := []struct { + name string + tokenClaims func() (string, *oidc.IDTokenClaims) + wantErr bool + }{ + { + name: "success", + tokenClaims: keySet.ValidIDToken, + }, + { + name: "parse err", + tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil }, + wantErr: true, + }, + { + name: "invalid signature", + tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil }, + wantErr: true, + }, + { + name: "wrong issuer", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + "foo", tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + { + name: "expired", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + { + name: "wrong IAT", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "", + ) + }, + wantErr: true, + }, + { + name: "wrong acr", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + "else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + { + name: "expired auth", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return keySet.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, want := tt.tokenClaims() + + got, err := VerifyIDTokenHint[*oidc.IDTokenClaims](context.Background(), token, verifier) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, got, want) + }) + } +}