Skip to content

Commit

Permalink
feat: Token Exchange (RFC 8693) (#255)
Browse files Browse the repository at this point in the history
This change implements OAuth2 Token Exchange in OP according to RFC 8693 (and client code)

Some implementation details:

- OP parses and verifies subject/actor tokens natively if they were issued by OP
- Third-party tokens verification is also possible by implementing additional storage interface
- Token exchange can issue only OP's native tokens (id_token, access_token and refresh_token) with static issuer
  • Loading branch information
lefelys authored Feb 19, 2023
1 parent 9291ca9 commit 8e29879
Show file tree
Hide file tree
Showing 16 changed files with 960 additions and 58 deletions.
25 changes: 25 additions & 0 deletions example/client/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,31 @@ func main() {
// w.Write(data)
//}

// you can also try token exchange flow
//
// requestTokenExchange := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
// data := make(url.Values)
// data.Set("grant_type", string(oidc.GrantTypeTokenExchange))
// data.Set("requested_token_type", string(oidc.IDTokenType))
// data.Set("subject_token", tokens.RefreshToken)
// data.Set("subject_token_type", string(oidc.RefreshTokenType))
// data.Add("scope", "profile custom_scope:impersonate:id2")

// client := &http.Client{}
// r2, _ := http.NewRequest(http.MethodPost, issuer+"/oauth/token", strings.NewReader(data.Encode()))
// // r2.Header.Add("Authorization", "Basic "+"d2ViOnNlY3JldA==")
// r2.Header.Add("Content-Type", "application/x-www-form-urlencoded")
// r2.SetBasicAuth("web", "secret")

// resp, _ := client.Do(r2)
// fmt.Println(resp.Status)

// b, _ := io.ReadAll(resp.Body)
// resp.Body.Close()

// w.Write(b)
// }

// register the CodeExchangeHandler at the callbackPath
// the CodeExchangeHandler handles the auth response, creates the token request and calls the callback function
// with the returned tokens from the token endpoint
Expand Down
4 changes: 2 additions & 2 deletions example/server/storage/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func NativeClient(id string, redirectURIs ...string) *Client {
loginURL: defaultLoginURL,
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken},
accessTokenType: 0,
accessTokenType: op.AccessTokenTypeBearer,
devMode: false,
idTokenUserinfoClaimsAssertion: false,
clockSkew: 0,
Expand All @@ -184,7 +184,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
loginURL: defaultLoginURL,
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken},
accessTokenType: 0,
accessTokenType: op.AccessTokenTypeBearer,
devMode: false,
idTokenUserinfoClaimsAssertion: false,
clockSkew: 0,
Expand Down
3 changes: 3 additions & 0 deletions example/server/storage/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ const (

// CustomClaim is an example for how to return custom claims with this library
CustomClaim = "custom_claim"

// CustomScopeImpersonatePrefix is an example scope prefix for passing user id to impersonate using token exchage
CustomScopeImpersonatePrefix = "custom_scope:impersonate:"
)

type AuthRequest struct {
Expand Down
135 changes: 131 additions & 4 deletions example/server/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"math/big"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -213,11 +215,14 @@ func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error {
// it will be called for all requests able to return an access token (Authorization Code Flow, Implicit Flow, JWT Profile, ...)
func (s *Storage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) {
var applicationID string
// if authenticated for an app (auth code / implicit flow) we must save the client_id to the token
authReq, ok := request.(*AuthRequest)
if ok {
applicationID = authReq.ApplicationID
switch req := request.(type) {
case *AuthRequest:
// if authenticated for an app (auth code / implicit flow) we must save the client_id to the token
applicationID = req.ApplicationID
case op.TokenExchangeRequest:
applicationID = req.GetClientID()
}

token, err := s.accessToken(applicationID, "", request.GetSubject(), request.GetAudience(), request.GetScopes())
if err != nil {
return "", time.Time{}, err
Expand All @@ -228,6 +233,11 @@ func (s *Storage) CreateAccessToken(ctx context.Context, request op.TokenRequest
// CreateAccessAndRefreshTokens implements the op.Storage interface
// it will be called for all requests able to return an access and refresh token (Authorization Code Flow, Refresh Token Request)
func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
// generate tokens via token exchange flow if request is relevant
if teReq, ok := request.(op.TokenExchangeRequest); ok {
return s.exchangeRefreshToken(ctx, teReq)
}

// get the information depending on the request type / implementation
applicationID, authTime, amr := getInfoFromRequest(request)

Expand Down Expand Up @@ -258,6 +268,24 @@ func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.T
return accessToken.ID, refreshToken, accessToken.Expiration, nil
}

func (s *Storage) exchangeRefreshToken(ctx context.Context, request op.TokenExchangeRequest) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
applicationID := request.GetClientID()
authTime := request.GetAuthTime()

refreshTokenID := uuid.NewString()
accessToken, err := s.accessToken(applicationID, refreshTokenID, request.GetSubject(), request.GetAudience(), request.GetScopes())
if err != nil {
return "", "", time.Time{}, err
}

refreshToken, err := s.createRefreshToken(accessToken, nil, authTime)
if err != nil {
return "", "", time.Time{}, err
}

return accessToken.ID, refreshToken, accessToken.Expiration, nil
}

// TokenRequestByRefreshToken implements the op.Storage interface
// it will be called after parsing and validation of the refresh token request
func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
Expand Down Expand Up @@ -444,6 +472,10 @@ func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspection o
// GetPrivateClaimsFromScopes implements the op.Storage interface
// it will be called for the creation of a JWT access token to assert claims for custom scopes
func (s *Storage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) {
return s.getPrivateClaimsFromScopes(ctx, userID, clientID, scopes)
}

func (s *Storage) getPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) {
for _, scope := range scopes {
switch scope {
case CustomScope:
Expand Down Expand Up @@ -580,6 +612,101 @@ func (s *Storage) setUserinfo(ctx context.Context, userInfo oidc.UserInfoSetter,
return nil
}

// ValidateTokenExchangeRequest implements the op.TokenExchangeStorage interface
// it will be called to validate parsed Token Exchange Grant request
func (s *Storage) ValidateTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) error {
if request.GetRequestedTokenType() == "" {
request.SetRequestedTokenType(oidc.RefreshTokenType)
}

// Just an example, some use cases might need this use case
if request.GetExchangeSubjectTokenType() == oidc.IDTokenType && request.GetRequestedTokenType() == oidc.RefreshTokenType {
return errors.New("exchanging id_token to refresh_token is not supported")
}

// Check impersonation permissions
if request.GetExchangeActor() == "" && !s.userStore.GetUserByID(request.GetExchangeSubject()).IsAdmin {
return errors.New("user doesn't have impersonation permission")
}

allowedScopes := make([]string, 0)
for _, scope := range request.GetScopes() {
if scope == oidc.ScopeAddress {
continue
}

if strings.HasPrefix(scope, CustomScopeImpersonatePrefix) {
subject := strings.TrimPrefix(scope, CustomScopeImpersonatePrefix)
request.SetSubject(subject)
}

allowedScopes = append(allowedScopes, scope)
}

request.SetCurrentScopes(allowedScopes)

return nil
}

// ValidateTokenExchangeRequest implements the op.TokenExchangeStorage interface
// Common use case is to store request for audit purposes. For this example we skip the storing.
func (s *Storage) CreateTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) error {
return nil
}

// GetPrivateClaimsFromScopesForTokenExchange implements the op.TokenExchangeStorage interface
// it will be called for the creation of an exchanged JWT access token to assert claims for custom scopes
// plus adding token exchange specific claims related to delegation or impersonation
func (s *Storage) GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]interface{}, err error) {
claims, err = s.getPrivateClaimsFromScopes(ctx, "", request.GetClientID(), request.GetScopes())
if err != nil {
return nil, err
}

for k, v := range s.getTokenExchangeClaims(ctx, request) {
claims = appendClaim(claims, k, v)
}

return claims, nil
}

// SetUserinfoFromScopesForTokenExchange implements the op.TokenExchangeStorage interface
// it will be called for the creation of an id_token - we are using the same private function as for other flows,
// plus adding token exchange specific claims related to delegation or impersonation
func (s *Storage) SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo oidc.UserInfoSetter, request op.TokenExchangeRequest) error {
err := s.setUserinfo(ctx, userinfo, request.GetSubject(), request.GetClientID(), request.GetScopes())
if err != nil {
return err
}

for k, v := range s.getTokenExchangeClaims(ctx, request) {
userinfo.AppendClaims(k, v)
}

return nil
}

func (s *Storage) getTokenExchangeClaims(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]interface{}) {
for _, scope := range request.GetScopes() {
switch {
case strings.HasPrefix(scope, CustomScopeImpersonatePrefix) && request.GetExchangeActor() == "":
// Set actor subject claim for impersonation flow
claims = appendClaim(claims, "act", map[string]interface{}{
"sub": request.GetExchangeSubject(),
})
}
}

// Set actor subject claim for delegation flow
// if request.GetExchangeActor() != "" {
// claims = appendClaim(claims, "act", map[string]interface{}{
// "sub": request.GetExchangeActor(),
// })
// }

return claims
}

// getInfoFromRequest returns the clientID, authTime and amr depending on the op.TokenRequest type / implementation
func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Time, amr []string) {
authReq, ok := req.(*AuthRequest) // Code Flow (with scope offline_access)
Expand Down
15 changes: 15 additions & 0 deletions example/server/storage/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type User struct {
Phone string
PhoneVerified bool
PreferredLanguage language.Tag
IsAdmin bool
}

type Service struct {
Expand Down Expand Up @@ -49,6 +50,20 @@ func NewUserStore(issuer string) UserStore {
Phone: "",
PhoneVerified: false,
PreferredLanguage: language.German,
IsAdmin: true,
},
"id2": {
ID: "id2",
Username: "test-user2",
Password: "verysecure",
FirstName: "Test",
LastName: "User2",
Email: "test-user2@zitadel.ch",
EmailVerified: true,
Phone: "",
PhoneVerified: false,
PreferredLanguage: language.German,
IsAdmin: false,
},
},
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndS
return http.ErrUseLastResponse
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
body, err := io.ReadAll(resp.Body)
Expand Down Expand Up @@ -148,6 +151,18 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa
return nil
}

func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil {
return nil, err
}
tokenRes := new(oidc.TokenExchangeResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
return nil, err
}
return tokenRes, nil
}

func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) {
privateKey, err := crypto.BytesToPrivateKey(key)
if err != nil {
Expand Down
Loading

0 comments on commit 8e29879

Please sign in to comment.