Skip to content

Commit

Permalink
oidc: refactor JWTProfileAssertionClaims
Browse files Browse the repository at this point in the history
  • Loading branch information
muhlemmer committed Mar 2, 2023
1 parent 3d4f358 commit c3c7031
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 108 deletions.
4 changes: 2 additions & 2 deletions pkg/oidc/regression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,15 @@ var (
"foo": "bar",
},
}
jwtProfileAssertionRegressData = &jwtProfileAssertion{
jwtProfileAssertionRegressData = &JWTProfileAssertionClaims{
PrivateKeyID: "8888",
PrivateKey: []byte("qwerty"),
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Expiration: 12345,
IssuedAt: 12000,
customClaims: map[string]interface{}{
Claims: map[string]interface{}{
"foo": "bar",
},
}
Expand Down
129 changes: 23 additions & 106 deletions pkg/oidc/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package oidc

import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"time"

"golang.org/x/oauth2"
Expand Down Expand Up @@ -180,19 +179,7 @@ type AccessTokenResponse struct {
State string `json:"state,omitempty" schema:"state,omitempty"`
}

type JWTProfileAssertionClaims interface {
GetKeyID() string
GetPrivateKey() []byte
GetIssuer() string
GetSubject() string
GetAudience() []string
GetExpiration() time.Time
GetIssuedAt() time.Time
SetCustomClaim(key string, value interface{})
GetCustomClaim(key string) interface{}
}

type jwtProfileAssertion struct {
type JWTProfileAssertionClaims struct {
PrivateKeyID string `json:"-"`
PrivateKey []byte `json:"-"`
Issuer string `json:"iss"`
Expand All @@ -201,91 +188,21 @@ type jwtProfileAssertion struct {
Expiration Time `json:"exp"`
IssuedAt Time `json:"iat"`

customClaims map[string]interface{}
}

func (j *jwtProfileAssertion) MarshalJSON() ([]byte, error) {
type Alias jwtProfileAssertion
a := (*Alias)(j)

b, err := json.Marshal(a)
if err != nil {
return nil, err
}

if len(j.customClaims) == 0 {
return b, nil
}

err = json.Unmarshal(b, &j.customClaims)
if err != nil {
return nil, fmt.Errorf("jws: invalid map of custom claims %v", j.customClaims)
}

return json.Marshal(j.customClaims)
}

func (j *jwtProfileAssertion) UnmarshalJSON(data []byte) error {
type Alias jwtProfileAssertion
a := (*Alias)(j)

err := json.Unmarshal(data, a)
if err != nil {
return err
}

err = json.Unmarshal(data, &j.customClaims)
if err != nil {
return err
}

return nil
}

func (j *jwtProfileAssertion) GetKeyID() string {
return j.PrivateKeyID
}

func (j *jwtProfileAssertion) GetPrivateKey() []byte {
return j.PrivateKey
Claims map[string]interface{} `json:"-"`
}

func (j *jwtProfileAssertion) SetCustomClaim(key string, value interface{}) {
if j.customClaims == nil {
j.customClaims = make(map[string]interface{})
}
j.customClaims[key] = value
}

func (j *jwtProfileAssertion) GetCustomClaim(key string) interface{} {
if j.customClaims == nil {
return nil
}
return j.customClaims[key]
}

func (j *jwtProfileAssertion) GetIssuer() string {
return j.Issuer
}

func (j *jwtProfileAssertion) GetSubject() string {
return j.Subject
}

func (j *jwtProfileAssertion) GetAudience() []string {
return j.Audience
}
type jpaAlias JWTProfileAssertionClaims

func (j *jwtProfileAssertion) GetExpiration() time.Time {
return j.Expiration.AsTime()
func (j *JWTProfileAssertionClaims) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims((*jpaAlias)(j), j.Claims)
}

func (j *jwtProfileAssertion) GetIssuedAt() time.Time {
return j.IssuedAt.AsTime()
func (j *JWTProfileAssertionClaims) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, (*jpaAlias)(j), &j.Claims)
}

func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) {
data, err := ioutil.ReadFile(filename)
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) {
data, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
Expand All @@ -305,19 +222,19 @@ func NewJWTProfileAssertionStringFromFileData(data []byte, audience []string, op
return GenerateJWTProfileToken(NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...))
}

func JWTProfileDelegatedSubject(sub string) func(*jwtProfileAssertion) {
return func(j *jwtProfileAssertion) {
func JWTProfileDelegatedSubject(sub string) func(*JWTProfileAssertionClaims) {
return func(j *JWTProfileAssertionClaims) {
j.Subject = sub
}
}

func JWTProfileCustomClaim(key string, value interface{}) func(*jwtProfileAssertion) {
return func(j *jwtProfileAssertion) {
j.customClaims[key] = value
func JWTProfileCustomClaim(key string, value interface{}) func(*JWTProfileAssertionClaims) {
return func(j *JWTProfileAssertionClaims) {
j.Claims[key] = value
}
}

func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) {
func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) {
keyData := new(struct {
KeyID string `json:"keyId"`
Key string `json:"key"`
Expand All @@ -330,18 +247,18 @@ func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...
return NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...), nil
}

type AssertionOption func(*jwtProfileAssertion)
type AssertionOption func(*JWTProfileAssertionClaims)

func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) JWTProfileAssertionClaims {
j := &jwtProfileAssertion{
func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) *JWTProfileAssertionClaims {
j := &JWTProfileAssertionClaims{
PrivateKey: key,
PrivateKeyID: keyID,
Issuer: userID,
Subject: userID,
IssuedAt: FromTime(time.Now().UTC()),
Expiration: FromTime(time.Now().Add(1 * time.Hour).UTC()),
Audience: audience,
customClaims: make(map[string]interface{}),
Claims: make(map[string]interface{}),
}

for _, opt := range opts {
Expand Down Expand Up @@ -369,14 +286,14 @@ func AppendClientIDToAudience(clientID string, audience []string) []string {
return append(audience, clientID)
}

func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error) {
privateKey, err := crypto.BytesToPrivateKey(assertion.GetPrivateKey())
func GenerateJWTProfileToken(assertion *JWTProfileAssertionClaims) (string, error) {
privateKey, err := crypto.BytesToPrivateKey(assertion.PrivateKey)
if err != nil {
return "", err
}
key := jose.SigningKey{
Algorithm: jose.RS256,
Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.GetKeyID()},
Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.PrivateKeyID},
}
signer, err := jose.NewSigner(key, &jose.SignerOptions{})
if err != nil {
Expand Down

0 comments on commit c3c7031

Please sign in to comment.