Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Token Exchange (RFC 8693) #255

Merged
merged 7 commits into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions example/client/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,31 @@ func main() {
// w.Write(data)
//}

// you can also try token exchange flow
//
// requestTokenExchange := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
// data := make(url.Values)
// data.Set("grant_type", string(oidc.GrantTypeTokenExchange))
// data.Set("requested_token_type", string(oidc.IDTokenType))
// data.Set("subject_token", tokens.RefreshToken)
// data.Set("subject_token_type", string(oidc.RefreshTokenType))
// data.Add("scope", "profile custom_scope:impersonate:id2")

// client := &http.Client{}
// r2, _ := http.NewRequest(http.MethodPost, issuer+"/oauth/token", strings.NewReader(data.Encode()))
// // r2.Header.Add("Authorization", "Basic "+"d2ViOnNlY3JldA==")
// r2.Header.Add("Content-Type", "application/x-www-form-urlencoded")
// r2.SetBasicAuth("web", "secret")

// resp, _ := client.Do(r2)
// fmt.Println(resp.Status)

// b, _ := io.ReadAll(resp.Body)
// resp.Body.Close()

// w.Write(b)
// }

// register the CodeExchangeHandler at the callbackPath
// the CodeExchangeHandler handles the auth response, creates the token request and calls the callback function
// with the returned tokens from the token endpoint
Expand Down
4 changes: 2 additions & 2 deletions example/server/storage/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func NativeClient(id string, redirectURIs ...string) *Client {
loginURL: defaultLoginURL,
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken},
accessTokenType: 0,
accessTokenType: op.AccessTokenTypeBearer,
devMode: false,
idTokenUserinfoClaimsAssertion: false,
clockSkew: 0,
Expand All @@ -184,7 +184,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
loginURL: defaultLoginURL,
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken},
accessTokenType: 0,
accessTokenType: op.AccessTokenTypeBearer,
devMode: false,
idTokenUserinfoClaimsAssertion: false,
clockSkew: 0,
Expand Down
3 changes: 3 additions & 0 deletions example/server/storage/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ const (

// CustomClaim is an example for how to return custom claims with this library
CustomClaim = "custom_claim"

// CustomScopeImpersonatePrefix is an example scope prefix for passing user id to impersonate using token exchage
CustomScopeImpersonatePrefix = "custom_scope:impersonate:"
)

type AuthRequest struct {
Expand Down
135 changes: 131 additions & 4 deletions example/server/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"math/big"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -213,11 +215,14 @@ func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error {
// it will be called for all requests able to return an access token (Authorization Code Flow, Implicit Flow, JWT Profile, ...)
func (s *Storage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) {
var applicationID string
// if authenticated for an app (auth code / implicit flow) we must save the client_id to the token
authReq, ok := request.(*AuthRequest)
if ok {
applicationID = authReq.ApplicationID
switch req := request.(type) {
case *AuthRequest:
// if authenticated for an app (auth code / implicit flow) we must save the client_id to the token
applicationID = req.ApplicationID
case op.TokenExchangeRequest:
applicationID = req.GetClientID()
}

token, err := s.accessToken(applicationID, "", request.GetSubject(), request.GetAudience(), request.GetScopes())
if err != nil {
return "", time.Time{}, err
Expand All @@ -228,6 +233,11 @@ func (s *Storage) CreateAccessToken(ctx context.Context, request op.TokenRequest
// CreateAccessAndRefreshTokens implements the op.Storage interface
// it will be called for all requests able to return an access and refresh token (Authorization Code Flow, Refresh Token Request)
func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
// generate tokens via token exchange flow if request is relevant
if teReq, ok := request.(op.TokenExchangeRequest); ok {
return s.exchangeRefreshToken(ctx, teReq)
}

// get the information depending on the request type / implementation
applicationID, authTime, amr := getInfoFromRequest(request)

Expand Down Expand Up @@ -258,6 +268,24 @@ func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.T
return accessToken.ID, refreshToken, accessToken.Expiration, nil
}

func (s *Storage) exchangeRefreshToken(ctx context.Context, request op.TokenExchangeRequest) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
applicationID := request.GetClientID()
authTime := request.GetAuthTime()

refreshTokenID := uuid.NewString()
accessToken, err := s.accessToken(applicationID, refreshTokenID, request.GetSubject(), request.GetAudience(), request.GetScopes())
if err != nil {
return "", "", time.Time{}, err
}

refreshToken, err := s.createRefreshToken(accessToken, nil, authTime)
if err != nil {
return "", "", time.Time{}, err
}

return accessToken.ID, refreshToken, accessToken.Expiration, nil
}

// TokenRequestByRefreshToken implements the op.Storage interface
// it will be called after parsing and validation of the refresh token request
func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
Expand Down Expand Up @@ -444,6 +472,10 @@ func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspection o
// GetPrivateClaimsFromScopes implements the op.Storage interface
// it will be called for the creation of a JWT access token to assert claims for custom scopes
func (s *Storage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) {
return s.getPrivateClaimsFromScopes(ctx, userID, clientID, scopes)
}

func (s *Storage) getPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) {
muhlemmer marked this conversation as resolved.
Show resolved Hide resolved
for _, scope := range scopes {
switch scope {
case CustomScope:
Expand Down Expand Up @@ -580,6 +612,101 @@ func (s *Storage) setUserinfo(ctx context.Context, userInfo oidc.UserInfoSetter,
return nil
}

// ValidateTokenExchangeRequest implements the op.TokenExchangeStorage interface
// it will be called to validate parsed Token Exchange Grant request
func (s *Storage) ValidateTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) error {
if request.GetRequestedTokenType() == "" {
request.SetRequestedTokenType(oidc.RefreshTokenType)
}

// Just an example, some use cases might need this use case
if request.GetExchangeSubjectTokenType() == oidc.IDTokenType && request.GetRequestedTokenType() == oidc.RefreshTokenType {
return errors.New("exchanging id_token to refresh_token is not supported")
}

// Check impersonation permissions
if request.GetExchangeActor() == "" && !s.userStore.GetUserByID(request.GetExchangeSubject()).IsAdmin {
return errors.New("user doesn't have impersonation permission")
}

allowedScopes := make([]string, 0)
for _, scope := range request.GetScopes() {
if scope == oidc.ScopeAddress {
continue
}

if strings.HasPrefix(scope, CustomScopeImpersonatePrefix) {
subject := strings.TrimPrefix(scope, CustomScopeImpersonatePrefix)
request.SetSubject(subject)
}

allowedScopes = append(allowedScopes, scope)
}

request.SetCurrentScopes(allowedScopes)

return nil
}

// ValidateTokenExchangeRequest implements the op.TokenExchangeStorage interface
// Common use case is to store request for audit purposes. For this example we skip the storing.
func (s *Storage) CreateTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) error {
return nil
}

// GetPrivateClaimsFromScopesForTokenExchange implements the op.TokenExchangeStorage interface
// it will be called for the creation of an exchanged JWT access token to assert claims for custom scopes
// plus adding token exchange specific claims related to delegation or impersonation
func (s *Storage) GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]interface{}, err error) {
claims, err = s.getPrivateClaimsFromScopes(ctx, "", request.GetClientID(), request.GetScopes())
if err != nil {
return nil, err
}

for k, v := range s.getTokenExchangeClaims(ctx, request) {
claims = appendClaim(claims, k, v)
}

return claims, nil
}

// SetUserinfoFromScopesForTokenExchange implements the op.TokenExchangeStorage interface
// it will be called for the creation of an id_token - we are using the same private function as for other flows,
// plus adding token exchange specific claims related to delegation or impersonation
func (s *Storage) SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo oidc.UserInfoSetter, request op.TokenExchangeRequest) error {
err := s.setUserinfo(ctx, userinfo, request.GetSubject(), request.GetClientID(), request.GetScopes())
if err != nil {
return err
}

for k, v := range s.getTokenExchangeClaims(ctx, request) {
userinfo.AppendClaims(k, v)
}

return nil
}

func (s *Storage) getTokenExchangeClaims(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]interface{}) {
for _, scope := range request.GetScopes() {
switch {
case strings.HasPrefix(scope, CustomScopeImpersonatePrefix) && request.GetExchangeActor() == "":
// Set actor subject claim for impersonation flow
claims = appendClaim(claims, "act", map[string]interface{}{
"sub": request.GetExchangeSubject(),
})
}
}

// Set actor subject claim for delegation flow
// if request.GetExchangeActor() != "" {
// claims = appendClaim(claims, "act", map[string]interface{}{
// "sub": request.GetExchangeActor(),
// })
// }

return claims
}

// getInfoFromRequest returns the clientID, authTime and amr depending on the op.TokenRequest type / implementation
func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Time, amr []string) {
authReq, ok := req.(*AuthRequest) // Code Flow (with scope offline_access)
Expand Down
15 changes: 15 additions & 0 deletions example/server/storage/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type User struct {
Phone string
PhoneVerified bool
PreferredLanguage language.Tag
IsAdmin bool
}

type Service struct {
Expand Down Expand Up @@ -49,6 +50,20 @@ func NewUserStore(issuer string) UserStore {
Phone: "",
PhoneVerified: false,
PreferredLanguage: language.German,
IsAdmin: true,
},
"id2": {
ID: "id2",
Username: "test-user2",
Password: "verysecure",
FirstName: "Test",
LastName: "User2",
Email: "test-user2@zitadel.ch",
EmailVerified: true,
Phone: "",
PhoneVerified: false,
PreferredLanguage: language.German,
IsAdmin: false,
},
},
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndS
return http.ErrUseLastResponse
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
body, err := io.ReadAll(resp.Body)
Expand Down Expand Up @@ -148,6 +151,18 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa
return nil
}

func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil {
return nil, err
}
tokenRes := new(oidc.TokenExchangeResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
return nil, err
}
return tokenRes, nil
}

func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) {
privateKey, err := crypto.BytesToPrivateKey(key)
if err != nil {
Expand Down
Loading