diff --git a/pkg/oidc/regression_data/oidc.jwtProfileAssertion.json b/pkg/oidc/regression_data/oidc.JWTProfileAssertionClaims.json similarity index 100% rename from pkg/oidc/regression_data/oidc.jwtProfileAssertion.json rename to pkg/oidc/regression_data/oidc.JWTProfileAssertionClaims.json diff --git a/pkg/oidc/regression_test.go b/pkg/oidc/regression_test.go index 9ca77588..519ed50f 100644 --- a/pkg/oidc/regression_test.go +++ b/pkg/oidc/regression_test.go @@ -149,7 +149,7 @@ var ( "foo": "bar", }, } - jwtProfileAssertionRegressData = &jwtProfileAssertion{ + jwtProfileAssertionRegressData = &JWTProfileAssertionClaims{ PrivateKeyID: "8888", PrivateKey: []byte("qwerty"), Issuer: "zitadel", @@ -157,7 +157,7 @@ var ( Audience: Audience{"foo", "bar"}, Expiration: 12345, IssuedAt: 12000, - customClaims: map[string]interface{}{ + Claims: map[string]interface{}{ "foo": "bar", }, } diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index f6635407..3307176f 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -2,8 +2,7 @@ package oidc import ( "encoding/json" - "fmt" - "io/ioutil" + "os" "time" "golang.org/x/oauth2" @@ -180,19 +179,7 @@ type AccessTokenResponse struct { State string `json:"state,omitempty" schema:"state,omitempty"` } -type JWTProfileAssertionClaims interface { - GetKeyID() string - GetPrivateKey() []byte - GetIssuer() string - GetSubject() string - GetAudience() []string - GetExpiration() time.Time - GetIssuedAt() time.Time - SetCustomClaim(key string, value interface{}) - GetCustomClaim(key string) interface{} -} - -type jwtProfileAssertion struct { +type JWTProfileAssertionClaims struct { PrivateKeyID string `json:"-"` PrivateKey []byte `json:"-"` Issuer string `json:"iss"` @@ -201,91 +188,21 @@ type jwtProfileAssertion struct { Expiration Time `json:"exp"` IssuedAt Time `json:"iat"` - customClaims map[string]interface{} -} - -func (j *jwtProfileAssertion) MarshalJSON() ([]byte, error) { - type Alias jwtProfileAssertion - a := (*Alias)(j) - - b, err := json.Marshal(a) - if err != nil { - return nil, err - } - - if len(j.customClaims) == 0 { - return b, nil - } - - err = json.Unmarshal(b, &j.customClaims) - if err != nil { - return nil, fmt.Errorf("jws: invalid map of custom claims %v", j.customClaims) - } - - return json.Marshal(j.customClaims) -} - -func (j *jwtProfileAssertion) UnmarshalJSON(data []byte) error { - type Alias jwtProfileAssertion - a := (*Alias)(j) - - err := json.Unmarshal(data, a) - if err != nil { - return err - } - - err = json.Unmarshal(data, &j.customClaims) - if err != nil { - return err - } - - return nil -} - -func (j *jwtProfileAssertion) GetKeyID() string { - return j.PrivateKeyID -} - -func (j *jwtProfileAssertion) GetPrivateKey() []byte { - return j.PrivateKey + Claims map[string]interface{} `json:"-"` } -func (j *jwtProfileAssertion) SetCustomClaim(key string, value interface{}) { - if j.customClaims == nil { - j.customClaims = make(map[string]interface{}) - } - j.customClaims[key] = value -} - -func (j *jwtProfileAssertion) GetCustomClaim(key string) interface{} { - if j.customClaims == nil { - return nil - } - return j.customClaims[key] -} - -func (j *jwtProfileAssertion) GetIssuer() string { - return j.Issuer -} - -func (j *jwtProfileAssertion) GetSubject() string { - return j.Subject -} - -func (j *jwtProfileAssertion) GetAudience() []string { - return j.Audience -} +type jpaAlias JWTProfileAssertionClaims -func (j *jwtProfileAssertion) GetExpiration() time.Time { - return j.Expiration.AsTime() +func (j *JWTProfileAssertionClaims) MarshalJSON() ([]byte, error) { + return mergeAndMarshalClaims((*jpaAlias)(j), j.Claims) } -func (j *jwtProfileAssertion) GetIssuedAt() time.Time { - return j.IssuedAt.AsTime() +func (j *JWTProfileAssertionClaims) UnmarshalJSON(data []byte) error { + return unmarshalJSONMulti(data, (*jpaAlias)(j), &j.Claims) } -func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) { - data, err := ioutil.ReadFile(filename) +func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) { + data, err := os.ReadFile(filename) if err != nil { return nil, err } @@ -305,19 +222,19 @@ func NewJWTProfileAssertionStringFromFileData(data []byte, audience []string, op return GenerateJWTProfileToken(NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...)) } -func JWTProfileDelegatedSubject(sub string) func(*jwtProfileAssertion) { - return func(j *jwtProfileAssertion) { +func JWTProfileDelegatedSubject(sub string) func(*JWTProfileAssertionClaims) { + return func(j *JWTProfileAssertionClaims) { j.Subject = sub } } -func JWTProfileCustomClaim(key string, value interface{}) func(*jwtProfileAssertion) { - return func(j *jwtProfileAssertion) { - j.customClaims[key] = value +func JWTProfileCustomClaim(key string, value interface{}) func(*JWTProfileAssertionClaims) { + return func(j *JWTProfileAssertionClaims) { + j.Claims[key] = value } } -func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) { +func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) { keyData := new(struct { KeyID string `json:"keyId"` Key string `json:"key"` @@ -330,10 +247,10 @@ func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ... return NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...), nil } -type AssertionOption func(*jwtProfileAssertion) +type AssertionOption func(*JWTProfileAssertionClaims) -func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) JWTProfileAssertionClaims { - j := &jwtProfileAssertion{ +func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) *JWTProfileAssertionClaims { + j := &JWTProfileAssertionClaims{ PrivateKey: key, PrivateKeyID: keyID, Issuer: userID, @@ -341,7 +258,7 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, IssuedAt: FromTime(time.Now().UTC()), Expiration: FromTime(time.Now().Add(1 * time.Hour).UTC()), Audience: audience, - customClaims: make(map[string]interface{}), + Claims: make(map[string]interface{}), } for _, opt := range opts { @@ -369,14 +286,14 @@ func AppendClientIDToAudience(clientID string, audience []string) []string { return append(audience, clientID) } -func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error) { - privateKey, err := crypto.BytesToPrivateKey(assertion.GetPrivateKey()) +func GenerateJWTProfileToken(assertion *JWTProfileAssertionClaims) (string, error) { + privateKey, err := crypto.BytesToPrivateKey(assertion.PrivateKey) if err != nil { return "", err } key := jose.SigningKey{ Algorithm: jose.RS256, - Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.GetKeyID()}, + Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.PrivateKeyID}, } signer, err := jose.NewSigner(key, &jose.SignerOptions{}) if err != nil {