diff --git a/parser_option.go b/parser_option.go index 88a780fb..3bf01369 100644 --- a/parser_option.go +++ b/parser_option.go @@ -66,15 +66,15 @@ func WithExpirationRequired() ParserOption { } } -// WithAudience configures the validator to require the specified audience in -// the `aud` claim. Validation will fail if the audience is not listed in the -// token or the `aud` claim is missing. +// WithAudience configures the validator to require ONE of the specified +// audiences to be present in the `aud` claim. Validation will fail if none of +// the audiences is listed in the token or the `aud` claim is missing. // // NOTE: While the `aud` claim is OPTIONAL in a JWT, the handling of it is // application-specific. Since this validation API is helping developers in // writing secure application, we decided to REQUIRE the existence of the claim, // if an audience is expected. -func WithAudience(aud string) ParserOption { +func WithAudience(aud ...string) ParserOption { return func(p *Parser) { p.validator.expectedAud = aud } diff --git a/validator.go b/validator.go index 008ecd87..bb43ccac 100644 --- a/validator.go +++ b/validator.go @@ -53,7 +53,7 @@ type Validator struct { // expectedAud contains the audience this token expects. Supplying an empty // string will disable aud checking. - expectedAud string + expectedAud []string // expectedIss contains the issuer this token expects. Supplying an empty // string will disable iss checking. @@ -120,7 +120,7 @@ func (v *Validator) Validate(claims Claims) error { } // If we have an expected audience, we also require the audience claim - if v.expectedAud != "" { + if len(v.expectedAud) > 0 { if err = v.verifyAudience(claims, v.expectedAud, true); err != nil { errs = append(errs, err) } @@ -226,7 +226,7 @@ func (v *Validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) // // Additionally, if any error occurs while retrieving the claim, e.g., when its // the wrong type, an ErrTokenUnverifiable error will be returned. -func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) error { +func (v *Validator) verifyAudience(claims Claims, cmp []string, required bool) error { aud, err := claims.GetAudience() if err != nil { return err @@ -241,10 +241,12 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err var stringClaims string for _, a := range aud { - if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { - result = true + for _, c := range cmp { + if subtle.ConstantTimeCompare([]byte(a), []byte(c)) != 0 { + result = true + } + stringClaims = stringClaims + a } - stringClaims = stringClaims + a } // case where "" is sent in one or many aud claims diff --git a/validator_test.go b/validator_test.go index 08a6bd71..bdc0b4b4 100644 --- a/validator_test.go +++ b/validator_test.go @@ -25,7 +25,7 @@ func Test_Validator_Validate(t *testing.T) { leeway time.Duration timeFunc func() time.Time verifyIat bool - expectedAud string + expectedAud []string expectedIss string expectedSub string } @@ -259,3 +259,50 @@ func Test_Validator_verifyIssuedAt(t *testing.T) { }) } } + +func Test_Validator_verifyAudience(t *testing.T) { + type fields struct { + expectedAud []string + } + type args struct { + claims Claims + cmp []string + required bool + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + { + name: "single value in aud claim", + fields: fields{expectedAud: []string{"me", "you"}}, + args: args{claims: MapClaims{"aud": "me"}, cmp: []string{"me"}}, + wantErr: nil, + }, + { + name: "multiple values in aud claim", + fields: fields{expectedAud: []string{"me"}}, + args: args{claims: MapClaims{"aud": []string{"me", "you"}}, cmp: []string{"me"}}, + wantErr: nil, + }, + { + name: "claims with invalid audience", + fields: fields{expectedAud: []string{"me"}}, + args: args{claims: MapClaims{"aud": "you"}, cmp: []string{"me"}}, + wantErr: ErrTokenInvalidAudience, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := &Validator{ + expectedAud: tt.fields.expectedAud, + } + err := v.verifyAudience(tt.args.claims, tt.args.cmp, tt.args.required) + if (err != nil) && !errors.Is(err, tt.wantErr) { + t.Errorf("validator.verifyAudience() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}