diff --git a/.github/workflows/issue.yml b/.github/workflows/issue.yml index 86718205..362443de 100644 --- a/.github/workflows/issue.yml +++ b/.github/workflows/issue.yml @@ -10,7 +10,7 @@ jobs: name: Add issue to project runs-on: ubuntu-latest steps: - - uses: actions/add-to-project@v0.4.1 + - uses: actions/add-to-project@v0.5.0 with: # You can target a repository in a different organization # to the issue diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2abef362..7483b2f7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -21,11 +21,11 @@ jobs: steps: - uses: actions/checkout@v3 - name: Setup go - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/... - - uses: codecov/codecov-action@v3.1.1 + - uses: codecov/codecov-action@v3.1.2 with: file: ./profile.cov name: codecov-go diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 120590fa..298bff69 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -108,7 +108,7 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, DeviceAuthorization: op.DeviceAuthorizationConfig{ Lifetime: 5 * time.Minute, PollInterval: 5 * time.Second, - UserFormURL: issuer + "device", + UserFormPath: "/device", UserCode: op.UserCodeBase20, }, } diff --git a/example/server/storage/client.go b/example/server/storage/client.go index 300ce0a2..a3e7cc45 100644 --- a/example/server/storage/client.go +++ b/example/server/storage/client.go @@ -32,6 +32,8 @@ type Client struct { devMode bool idTokenUserinfoClaimsAssertion bool clockSkew time.Duration + postLogoutRedirectURIGlobs []string + redirectURIGlobs []string } // GetID must return the client_id @@ -44,21 +46,11 @@ func (c *Client) RedirectURIs() []string { return c.redirectURIs } -// RedirectURIGlobs provide wildcarding for additional valid redirects -func (c *Client) RedirectURIGlobs() []string { - return nil -} - // PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs func (c *Client) PostLogoutRedirectURIs() []string { return []string{} } -// PostLogoutRedirectURIGlobs provide extra wildcarding for additional valid redirects -func (c *Client) PostLogoutRedirectURIGlobs() []string { - return nil -} - // ApplicationType must return the type of the client (app, native, user agent) func (c *Client) ApplicationType() op.ApplicationType { return c.applicationType @@ -200,3 +192,26 @@ func WebClient(id, secret string, redirectURIs ...string) *Client { clockSkew: 0, } } + +type hasRedirectGlobs struct { + *Client +} + +// RedirectURIGlobs provide wildcarding for additional valid redirects +func (c hasRedirectGlobs) RedirectURIGlobs() []string { + return c.redirectURIGlobs +} + +// PostLogoutRedirectURIGlobs provide extra wildcarding for additional valid redirects +func (c hasRedirectGlobs) PostLogoutRedirectURIGlobs() []string { + return c.postLogoutRedirectURIGlobs +} + +// RedirectGlobsClient wraps the client in a op.HasRedirectGlobs +// only if DevMode is enabled. +func RedirectGlobsClient(client *Client) op.Client { + if client.devMode { + return hasRedirectGlobs{client} + } + return client +} diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go index 2aeefe71..e1160b6b 100644 --- a/example/server/storage/storage.go +++ b/example/server/storage/storage.go @@ -418,7 +418,7 @@ func (s *Storage) GetClientByClientID(ctx context.Context, clientID string) (op. if !ok { return nil, fmt.Errorf("client not found") } - return client, nil + return RedirectGlobsClient(client), nil } // AuthorizeClientIDSecret implements the op.Storage interface @@ -438,10 +438,17 @@ func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientS return nil } -// SetUserinfoFromScopes implements the op.Storage interface -// it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check +// SetUserinfoFromScopes implements the op.Storage interface. +// Provide an empty implementation and use SetUserinfoFromRequest instead. func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error { - return s.setUserinfo(ctx, userinfo, userID, clientID, scopes) + return nil +} + +// SetUserinfoFromRequests implements the op.CanSetUserinfoFromRequest interface. In the +// next major release, it will be required for op.Storage. +// It will be called for the creation of an id_token, so we'll just pass it to the private function without any further check +func (s *Storage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, token op.IDTokenRequest, scopes []string) error { + return s.setUserinfo(ctx, userinfo, token.GetSubject(), token.GetClientID(), scopes) } // SetUserinfoFromToken implements the op.Storage interface diff --git a/example/server/storage/storage_dynamic.go b/example/server/storage/storage_dynamic.go index 3aec9d72..0d99aa27 100644 --- a/example/server/storage/storage_dynamic.go +++ b/example/server/storage/storage_dynamic.go @@ -196,8 +196,8 @@ func (s *multiStorage) AuthorizeClientIDSecret(ctx context.Context, clientID, cl return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret) } -// SetUserinfoFromScopes implements the op.Storage interface -// it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check +// SetUserinfoFromScopes implements the op.Storage interface. +// Provide an empty implementation and use SetUserinfoFromRequest instead. func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error { storage, err := s.storageFromContext(ctx) if err != nil { @@ -206,6 +206,17 @@ func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc return storage.SetUserinfoFromScopes(ctx, userinfo, userID, clientID, scopes) } +// SetUserinfoFromRequests implements the op.CanSetUserinfoFromRequest interface. In the +// next major release, it will be required for op.Storage. +// It will be called for the creation of an id_token, so we'll just pass it to the private function without any further check +func (s *multiStorage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, token op.IDTokenRequest, scopes []string) error { + storage, err := s.storageFromContext(ctx) + if err != nil { + return err + } + return storage.SetUserinfoFromRequest(ctx, userinfo, token, scopes) +} + // SetUserinfoFromToken implements the op.Storage interface // it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function func (s *multiStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error { diff --git a/go.mod b/go.mod index 8f571576..610d2a10 100644 --- a/go.mod +++ b/go.mod @@ -10,12 +10,12 @@ require ( github.com/gorilla/securecookie v1.1.1 github.com/jeremija/gosubmit v0.2.7 github.com/muhlemmer/gu v0.3.1 - github.com/rs/cors v1.8.3 + github.com/rs/cors v1.9.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.2 github.com/zitadel/schema v1.3.0 - golang.org/x/oauth2 v0.6.0 - golang.org/x/text v0.8.0 + golang.org/x/oauth2 v0.7.0 + golang.org/x/text v0.9.0 gopkg.in/square/go-jose.v2 v2.6.0 ) @@ -26,10 +26,10 @@ require ( github.com/google/go-querystring v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.7.0 // indirect - golang.org/x/net v0.8.0 // indirect - golang.org/x/sys v0.6.0 // indirect + golang.org/x/net v0.9.0 // indirect + golang.org/x/sys v0.7.0 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/protobuf v1.29.0 // indirect + google.golang.org/protobuf v1.29.1 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index b1757d49..c9c85626 100644 --- a/go.sum +++ b/go.sum @@ -34,8 +34,8 @@ github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM= github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/cors v1.8.3 h1:O+qNyWn7Z+F9M0ILBHgMVPuB1xTOucVd5gtaYyXBpRo= -github.com/rs/cors v1.8.3/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= +github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -60,11 +60,11 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= -golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= +golang.org/x/oauth2 v0.7.0 h1:qe6s0zUXlPX80/dITx3440hWZ7GwMwgDDyrSGTPJG/g= +golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -73,14 +73,14 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -93,8 +93,8 @@ google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6 google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.29.0 h1:44S3JjaKmLEE4YIkjzexaP+NzZsudE3Zin5Njn/pYX0= -google.golang.org/protobuf v1.29.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.29.1 h1:7QBf+IK2gx70Ap/hDsOmam3GE0v9HicjfEdAxE62UoM= +google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/pkg/client/client.go b/pkg/client/client.go index 37c7ec27..e3efd611 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -61,12 +61,18 @@ func callTokenEndpoint(ctx context.Context, request interface{}, authFn interfac if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil { return nil, err } - return &oauth2.Token{ + token := &oauth2.Token{ AccessToken: tokenRes.AccessToken, TokenType: tokenRes.TokenType, RefreshToken: tokenRes.RefreshToken, Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second), - }, nil + } + if tokenRes.IDToken != "" { + token = token.WithExtra(map[string]any{ + "id_token": tokenRes.IDToken, + }) + } + return token, nil } type EndSessionCaller interface { diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 2c3ef623..d8b3f255 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -68,6 +68,7 @@ func TestRelyingPartySession(t *testing.T) { t.Logf("new token type %s", newTokens.TokenType) t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) require.NotEmpty(t, newTokens.AccessToken, "new accessToken") + assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken") t.Log("------ end session (logout) ------") @@ -158,7 +159,6 @@ func TestResourceServerTokenExchange(t *testing.T) { require.Error(t, err, "refresh token") assert.Contains(t, err.Error(), "subject_token is invalid") require.Nil(t, tokenExchangeResponse, "token exchange response") - } func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) { diff --git a/pkg/client/profile/jwt_profile.go b/pkg/client/profile/jwt_profile.go index 668f749c..419f4175 100644 --- a/pkg/client/profile/jwt_profile.go +++ b/pkg/client/profile/jwt_profile.go @@ -17,7 +17,9 @@ type TokenSource interface { TokenCtx(context.Context) (*oauth2.Token, error) } -// jwtProfileTokenSource implements the TokenSource +// jwtProfileTokenSource implement the oauth2.TokenSource +// it will request a token using the OAuth2 JWT Profile Grant +// therefore sending an `assertion` by signing a JWT with the provided private key type jwtProfileTokenSource struct { clientID string audience []string diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 820107f6..b93a373c 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -599,6 +599,10 @@ type RefreshTokenRequest struct { GrantType oidc.GrantType `schema:"grant_type"` } +// RefreshAccessToken performs a token refresh. If it doesn't error, it will always +// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then +// the old one should be considered invalid. It may also provide a new IDToken. The +// new IDToken can be retrieved with token.Extra("id_token"). func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { request := RefreshTokenRequest{ RefreshToken: refreshToken, diff --git a/pkg/oidc/introspection_test.go b/pkg/oidc/introspection_test.go index bd498948..60cf8a47 100644 --- a/pkg/oidc/introspection_test.go +++ b/pkg/oidc/introspection_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "testing" + "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -25,7 +26,7 @@ func TestIntrospectionResponse_SetUserInfo(t *testing.T) { UserInfoProfile: userInfoData.UserInfoProfile, UserInfoEmail: userInfoData.UserInfoEmail, UserInfoPhone: userInfoData.UserInfoPhone, - Claims: userInfoData.Claims, + Claims: gu.MapCopy(userInfoData.Claims), }, }, { diff --git a/pkg/oidc/regression_assert_test.go b/pkg/oidc/regression_assert_test.go index 5e9fb3df..dd9f5ad7 100644 --- a/pkg/oidc/regression_assert_test.go +++ b/pkg/oidc/regression_assert_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "io" "os" + "reflect" "strings" "testing" @@ -38,10 +39,12 @@ func Test_assert_regression(t *testing.T) { assert.JSONEq(t, want, first) + target := reflect.New(reflect.TypeOf(obj).Elem()).Interface() + require.NoError(t, - json.Unmarshal([]byte(first), obj), + json.Unmarshal([]byte(first), target), ) - second, err := json.Marshal(obj) + second, err := json.Marshal(target) require.NoError(t, err) assert.JSONEq(t, want, string(second)) diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 83f3805d..c02eaf4b 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -8,6 +8,7 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" + "github.com/muhlemmer/gu" "github.com/zitadel/oidc/v3/pkg/crypto" ) @@ -157,6 +158,21 @@ func (t *IDTokenClaims) SetUserInfo(i *UserInfo) { t.UserInfoEmail = i.UserInfoEmail t.UserInfoPhone = i.UserInfoPhone t.Address = i.Address + if t.Claims == nil { + t.Claims = make(map[string]any, len(t.Claims)) + } + gu.MapMerge(i.Claims, t.Claims) +} + +func (t *IDTokenClaims) GetUserInfo() *UserInfo { + return &UserInfo{ + Subject: t.Subject, + UserInfoProfile: t.UserInfoProfile, + UserInfoEmail: t.UserInfoEmail, + UserInfoPhone: t.UserInfoPhone, + Address: t.Address, + Claims: gu.MapCopy(t.Claims), + } } func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) *IDTokenClaims { diff --git a/pkg/oidc/token_test.go b/pkg/oidc/token_test.go index 0d9874e9..ef1e77f8 100644 --- a/pkg/oidc/token_test.go +++ b/pkg/oidc/token_test.go @@ -181,6 +181,9 @@ func TestIDTokenClaims_SetUserInfo(t *testing.T) { UserInfoEmail: userInfoData.UserInfoEmail, UserInfoPhone: userInfoData.UserInfoPhone, Address: userInfoData.Address, + Claims: map[string]interface{}{ + "foo": "bar", + }, } var got IDTokenClaims @@ -225,3 +228,16 @@ func TestNewIDTokenClaims(t *testing.T) { assert.Equal(t, want, got) } + +func TestIDTokenClaims_GetUserInfo(t *testing.T) { + want := &UserInfo{ + Subject: idTokenData.Subject, + UserInfoProfile: idTokenData.UserInfoProfile, + UserInfoEmail: idTokenData.UserInfoEmail, + UserInfoPhone: idTokenData.UserInfoPhone, + Address: idTokenData.Address, + Claims: idTokenData.Claims, + } + got := idTokenData.GetUserInfo() + assert.Equal(t, want, got) +} diff --git a/pkg/oidc/userinfo_test.go b/pkg/oidc/userinfo_test.go index faab4e38..a574366d 100644 --- a/pkg/oidc/userinfo_test.go +++ b/pkg/oidc/userinfo_test.go @@ -52,11 +52,14 @@ func TestUserInfoMarshal(t *testing.T) { out := new(UserInfo) assert.NoError(t, json.Unmarshal(marshal, out)) - assert.Equal(t, userinfo, out) expected, err := json.Marshal(out) assert.NoError(t, err) assert.Equal(t, expected, marshal) + + out2 := new(UserInfo) + assert.NoError(t, json.Unmarshal(expected, out2)) + assert.Equal(t, out, out2) } func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) { diff --git a/pkg/oidc/util.go b/pkg/oidc/util.go index a89d75ee..462ea447 100644 --- a/pkg/oidc/util.go +++ b/pkg/oidc/util.go @@ -9,7 +9,7 @@ import ( // mergeAndMarshalClaims merges registered and the custom // claims map into a single JSON object. // Registered fields overwrite custom claims. -func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) { +func mergeAndMarshalClaims(registered any, extraClaims map[string]any) ([]byte, error) { // Use a buffer for memory re-use, instead off letting // json allocate a new []byte for every step. buf := new(bytes.Buffer) @@ -19,16 +19,21 @@ func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error return nil, fmt.Errorf("oidc registered claims: %w", err) } - if len(claims) > 0 { + if len(extraClaims) > 0 { + merged := make(map[string]any) + for k, v := range extraClaims { + merged[k] = v + } + // Merge JSON data into custom claims. // The full-read action by the decoder resets the buffer // to zero len, while retaining underlaying cap. - if err := json.NewDecoder(buf).Decode(&claims); err != nil { + if err := json.NewDecoder(buf).Decode(&merged); err != nil { return nil, fmt.Errorf("oidc registered claims: %w", err) } // Marshal the final result. - if err := json.NewEncoder(buf).Encode(claims); err != nil { + if err := json.NewEncoder(buf).Encode(merged); err != nil { return nil, fmt.Errorf("oidc custom claims: %w", err) } } diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index b516909b..7af3779e 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -67,7 +67,7 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, * func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer.Encoder()) return } ctx := r.Context() @@ -273,9 +273,9 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { return scopes, nil } -// checkURIAginstRedirects just checks aginst the valid redirect URIs and ignores +// checkURIAgainstRedirects just checks aginst the valid redirect URIs and ignores // other factors. -func checkURIAginstRedirects(client Client, uri string) error { +func checkURIAgainstRedirects(client Client, uri string) error { if str.Contains(client.RedirectURIs(), uri) { return nil } @@ -302,12 +302,12 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res "Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") } if strings.HasPrefix(uri, "https://") { - return checkURIAginstRedirects(client, uri) + return checkURIAgainstRedirects(client, uri) } if client.ApplicationType() == ApplicationTypeNative { return validateAuthReqRedirectURINative(client, uri, responseType) } - if err := checkURIAginstRedirects(client, uri); err != nil { + if err := checkURIAgainstRedirects(client, uri); err != nil { return err } if strings.HasPrefix(uri, "http://") { @@ -328,7 +328,7 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error { parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri) isCustomSchema := !strings.HasPrefix(uri, "http://") - if err := checkURIAginstRedirects(client, uri); err == nil { + if err := checkURIAgainstRedirects(client, uri); err == nil { if client.DevMode() { return nil } diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 4e801796..df340b6b 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -9,6 +9,7 @@ import ( "reflect" "testing" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" tu "github.com/zitadel/oidc/v3/internal/testutil" @@ -19,60 +20,34 @@ import ( "github.com/zitadel/schema" ) -// -// TOOD: tests will be implemented in branch for service accounts -// func TestAuthorize(t *testing.T) { -// // testCallback := func(t *testing.T, clienID string) callbackHandler { -// // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { -// // // require.Equal(t, clientID, client.) -// // } -// // } -// // testErr := func(t *testing.T, expected error) errorHandler { -// // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { -// // require.Equal(t, expected, err) -// // } -// // } -// type args struct { -// w http.ResponseWriter -// r *http.Request -// authorizer op.Authorizer -// } -// tests := []struct { -// name string -// args args -// }{ -// { -// "parsing fails", -// args{ -// httptest.NewRecorder(), -// &http.Request{Method: "POST", Body: nil}, -// mock.NewAuthorizerExpectValid(t, true), -// // testCallback(t, ""), -// // testErr(t, ErrInvalidRequest("cannot parse form")), -// }, -// }, -// { -// "decoding fails", -// args{ -// httptest.NewRecorder(), -// func() *http.Request { -// r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo")) -// r.Header.Set("Content-Type", "application/x-www-form-urlencoded") -// return r -// }(), -// mock.NewAuthorizerExpectValid(t, true), -// // testCallback(t, ""), -// // testErr(t, ErrInvalidRequest("cannot parse auth request")), -// }, -// }, -// // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}}, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer) -// }) -// } -//} +func TestAuthorize(t *testing.T) { + tests := []struct { + name string + req *http.Request + expect func(a *mock.MockAuthorizerMockRecorder) + }{ + { + name: "parse error", // used to panic, see issue #315 + req: httptest.NewRequest(http.MethodPost, "/?;", nil), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + authorizer := mock.NewMockAuthorizer(gomock.NewController(t)) + + expect := authorizer.EXPECT() + expect.Decoder().Return(schema.NewDecoder()) + expect.Encoder().Return(schema.NewEncoder()) + + if tt.expect != nil { + tt.expect(expect) + } + + op.Authorize(w, tt.req, authorizer) + }) + } +} func TestParseAuthorizeRequest(t *testing.T) { type args struct { diff --git a/pkg/op/client.go b/pkg/op/client.go index 754636cc..d01845f2 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -56,6 +56,12 @@ type Client interface { // interpretation. Redirect URIs that match either the non-glob version or the // glob version will be accepted. Glob URIs are only partially supported for native // clients: "http://" is not allowed except for loopback or in dev mode. +// +// Note that globbing / wildcards are not permitted by the OIDC +// standard and implementing this interface can have security implications. +// It is advised to only return a client of this type in rare cases, +// such as DevMode for the client being enabled. +// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest type HasRedirectGlobs interface { RedirectURIGlobs() []string PostLogoutRedirectURIGlobs() []string @@ -145,21 +151,30 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au } data := new(clientData) - if err = p.Decoder().Decode(data, r.PostForm); err != nil { + if err = p.Decoder().Decode(data, r.Form); err != nil { return "", false, err } JWTProfile, ok := p.(ClientJWTProfile) - if ok { + if ok && data.ClientAssertion != "" { + // if JWTProfile is supported and client sent an assertion, check it and use it as response + // regardless if it succeeded or failed clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile) + return clientID, err == nil, err } - if !ok || errors.Is(err, ErrNoClientCredentials) { - clientID, err = ClientBasicAuth(r, p.Storage()) - } + // try basic auth + clientID, err = ClientBasicAuth(r, p.Storage()) + // if that succeeded, use it if err == nil { return clientID, true, nil } + // if the client did not send a Basic Auth Header, ignore the `ErrNoClientCredentials` + // but return other errors immediately + if err != nil && !errors.Is(err, ErrNoClientCredentials) { + return "", false, err + } + // if the client did not authenticate (public clients) it must at least send a client_id if data.ClientID == "" { return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID) } diff --git a/pkg/op/device.go b/pkg/op/device.go index e54da706..09c7fca1 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -8,6 +8,7 @@ import ( "fmt" "math/big" "net/http" + "net/url" "strings" "time" @@ -18,7 +19,14 @@ import ( type DeviceAuthorizationConfig struct { Lifetime time.Duration PollInterval time.Duration - UserFormURL string // the URL where the user must go to authorize the device + + // UserFormURL is the complete URL where the user must go to authorize the device. + // Deprecated: use UserFormPath instead. + UserFormURL string + + // UserFormPath is the path where the user must go to authorize the device. + // The hostname for the URL is taken from the request by IssuerFromContext. + UserFormPath string UserCode UserCodeConfig } @@ -82,15 +90,28 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide return err } + var verification *url.URL + if config.UserFormURL != "" { + if verification, err = url.Parse(config.UserFormURL); err != nil { + return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + } + } else { + if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil { + return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + } + verification.Path = config.UserFormPath + } + response := &oidc.DeviceAuthorizationResponse{ DeviceCode: deviceCode, UserCode: userCode, - VerificationURI: config.UserFormURL, + VerificationURI: verification.String(), ExpiresIn: int(config.Lifetime / time.Second), Interval: int(config.PollInterval / time.Second), } - response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode) + verification.RawQuery = "user_code=" + userCode + response.VerificationURIComplete = verification.String() httphelper.MarshalJSON(w, response) return nil diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index ab117002..1e32554b 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -20,29 +21,60 @@ import ( ) func Test_deviceAuthorizationHandler(t *testing.T) { - req := &oidc.DeviceAuthorizationRequest{ - Scopes: []string{"foo", "bar"}, - ClientID: "web", + type conf struct { + UserFormURL string + UserFormPath string } - values := make(url.Values) - testProvider.Encoder().Encode(req, values) - body := strings.NewReader(values.Encode()) + tests := []struct { + name string + conf conf + }{ + { + name: "UserFormURL", + conf: conf{ + UserFormURL: "https://localhost:9998/device", + }, + }, + { + name: "UserFormPath", + conf: conf{ + UserFormPath: "/device", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf := gu.PtrCopy(testConfig) + conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL + conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath + provider := newTestProvider(conf) - r := httptest.NewRequest(http.MethodPost, "/", body) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req := &oidc.DeviceAuthorizationRequest{ + Scopes: []string{"foo", "bar"}, + ClientID: "web", + } + values := make(url.Values) + testProvider.Encoder().Encode(req, values) + body := strings.NewReader(values.Encode()) - w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/", body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer)) - runWithRandReader(mr.New(mr.NewSource(1)), func() { - op.DeviceAuthorizationHandler(testProvider)(w, r) - }) + w := httptest.NewRecorder() - result := w.Result() + runWithRandReader(mr.New(mr.NewSource(1)), func() { + op.DeviceAuthorizationHandler(provider)(w, r) + }) - assert.Less(t, result.StatusCode, 300) + result := w.Result() - got, _ := io.ReadAll(result.Body) - assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got)) + assert.Less(t, result.StatusCode, 300) + + got, _ := io.ReadAll(result.Body) + assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got)) + }) + } } func TestParseDeviceCodeRequest(t *testing.T) { diff --git a/pkg/op/op.go b/pkg/op/op.go index 1cdb3bc9..1fbe7801 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -480,6 +480,16 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option { } } +func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { + return func(o *Provider) error { + if err := endpoint.Validate(); err != nil { + return err + } + o.endpoints.DeviceAuthorization = endpoint + return nil + } +} + func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { return func(o *Provider) error { o.endpoints.Authorization = auth diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index 3958b89b..d347d048 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -20,15 +20,9 @@ import ( "golang.org/x/text/language" ) -var testProvider op.OpenIDProvider - -const ( - testIssuer = "https://localhost:9998/" - pathLoggedOut = "/logged-out" -) - -func init() { - config := &op.Config{ +var ( + testProvider op.OpenIDProvider + testConfig = &op.Config{ CryptoKey: sha256.Sum256([]byte("test")), DefaultLogoutRedirectURI: pathLoggedOut, CodeMethodS256: true, @@ -40,24 +34,35 @@ func init() { DeviceAuthorization: op.DeviceAuthorizationConfig{ Lifetime: 5 * time.Minute, PollInterval: 5 * time.Second, - UserFormURL: testIssuer + "device", + UserFormPath: "/device", UserCode: op.UserCodeBase20, }, } +) +const ( + testIssuer = "https://localhost:9998/" + pathLoggedOut = "/logged-out" +) + +func init() { storage.RegisterClients( storage.NativeClient("native"), storage.WebClient("web", "secret", "https://example.com"), storage.WebClient("api", "secret"), ) - var err error - testProvider, err = op.NewOpenIDProvider(testIssuer, config, + testProvider = newTestProvider(testConfig) +} + +func newTestProvider(config *op.Config) op.OpenIDProvider { + provider, err := op.NewOpenIDProvider(testIssuer, config, storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(), ) if err != nil { panic(err) } + return provider } type routesTestStorage interface { diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 25444ddb..aa8721ac 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -113,6 +113,8 @@ type OPStorage interface { // handle the current request. GetClientByClientID(ctx context.Context, clientID string) (Client, error) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error + // SetUserinfoFromScopes is deprecated and should have an empty implementation for now. + // Implement SetUserinfoFromRequest instead. SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error @@ -127,6 +129,13 @@ type JWTProfileTokenStorage interface { JWTProfileTokenType(ctx context.Context, request TokenRequest) (AccessTokenType, error) } +// CanSetUserinfoFromRequest is an optional additional interface that may be implemented by +// implementors of Storage. It allows additional data to be set in id_tokens based on the +// request. +type CanSetUserinfoFromRequest interface { + SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request IDTokenRequest, scopes []string) error +} + // Storage is a required parameter for NewOpenIDProvider(). In addition to the // embedded interfaces below, if the passed Storage implements ClientCredentialsStorage // then the grant type "client_credentials" will be supported. In that case, the access diff --git a/pkg/op/token.go b/pkg/op/token.go index 44648aac..22f67c4c 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -190,6 +190,12 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v if err != nil { return "", err } + if fromRequest, ok := storage.(CanSetUserinfoFromRequest); ok { + err := fromRequest.SetUserinfoFromRequest(ctx, userInfo, request, scopes) + if err != nil { + return "", err + } + } claims.SetUserInfo(userInfo) } if code != "" {