From fe60421e11155da642692c36b7bbc361e1310c3d Mon Sep 17 00:00:00 2001 From: Sean McGrail Date: Tue, 26 Jan 2021 10:16:03 -0800 Subject: [PATCH] Expand Credential Chain Tests --- config/resolve_credentials.go | 6 +- config/resolve_credentials_test.go | 188 ++++++++++++++++++++------ config/shared_config.go | 8 +- config/shared_config_test.go | 32 ++++- config/shared_test.go | 14 ++ config/testdata/config_source_shared | 24 ++++ config/testdata/shared_config | 21 +++ credentials/ssocreds/provider.go | 10 +- credentials/ssocreds/provider_test.go | 16 +-- internal/sdk/time.go | 11 ++ 10 files changed, 266 insertions(+), 64 deletions(-) diff --git a/config/resolve_credentials.go b/config/resolve_credentials.go index c74d1ef83d5..6e712c66764 100644 --- a/config/resolve_credentials.go +++ b/config/resolve_credentials.go @@ -110,9 +110,6 @@ func resolveCredentialChain(ctx context.Context, cfg *aws.Config, configs config func resolveCredsFromProfile(ctx context.Context, cfg *aws.Config, envConfig *EnvConfig, sharedConfig *SharedConfig, configs configs) (err error) { switch { - case sharedConfig.hasSSOConfiguration(): - err = resolveSSOCredentials(ctx, cfg, sharedConfig, configs) - case sharedConfig.Source != nil: // Assume IAM role with credentials source from a different profile. err = resolveCredsFromProfile(ctx, cfg, envConfig, sharedConfig.Source, configs) @@ -123,6 +120,9 @@ func resolveCredsFromProfile(ctx context.Context, cfg *aws.Config, envConfig *En Value: sharedConfig.Credentials, } + case sharedConfig.hasSSOConfiguration(): + err = resolveSSOCredentials(ctx, cfg, sharedConfig, configs) + case len(sharedConfig.CredentialProcess) != 0: // Get credentials from CredentialProcess err = processCredentials(ctx, cfg, sharedConfig, configs) diff --git a/config/resolve_credentials_test.go b/config/resolve_credentials_test.go index b1b8fcadbc6..aefc3fe2c79 100644 --- a/config/resolve_credentials_test.go +++ b/config/resolve_credentials_test.go @@ -3,6 +3,7 @@ package config import ( "context" "fmt" + "io/ioutil" "net/http" "net/http/httptest" "os" @@ -15,6 +16,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/internal/awstesting" + "github.com/aws/aws-sdk-go-v2/service/sso" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go/middleware" ) @@ -66,6 +68,14 @@ func setupCredentialsEndpoints(t *testing.T) (aws.EndpointResolver, func()) { Format("2006-01-02T15:04:05Z")))) })) + ssoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf( + getRoleCredentialsResponse, + time.Now(). + Add(15*time.Minute). + UnixNano()/int64(time.Millisecond)))) + })) + resolver := aws.EndpointResolverFunc( func(service, region string) (aws.Endpoint, error) { switch service { @@ -73,6 +83,10 @@ func setupCredentialsEndpoints(t *testing.T) (aws.EndpointResolver, func()) { return aws.Endpoint{ URL: stsServer.URL, }, nil + case sso.ServiceID: + return aws.Endpoint{ + URL: ssoServer.URL, + }, nil default: return aws.Endpoint{}, fmt.Errorf("unknown service endpoint, %s", service) @@ -87,6 +101,44 @@ func setupCredentialsEndpoints(t *testing.T) (aws.EndpointResolver, func()) { } } +func ssoTestSetup() (func(), error) { + dir, err := ioutil.TempDir("", "sso-test") + if err != nil { + return nil, err + } + + cacheDir := filepath.Join(dir, ".aws", "sso", "cache") + err = os.MkdirAll(cacheDir, 0750) + if err != nil { + os.RemoveAll(dir) + return nil, err + } + + tokenFile, err := os.Create(filepath.Join(cacheDir, "eb5e43e71ce87dd92ec58903d76debd8ee42aefd.json")) + if err != nil { + os.RemoveAll(dir) + return nil, err + } + defer tokenFile.Close() + + _, err = tokenFile.WriteString(fmt.Sprintf(ssoTokenCacheFile, time.Now(). + Add(15*time.Minute). + Format(time.RFC3339))) + if err != nil { + os.RemoveAll(dir) + return nil, err + } + + if runtime.GOOS == "windows" { + os.Setenv("USERPROFILE", dir) + } else { + os.Setenv("HOME", dir) + } + + return func() { + }, nil +} + func TestSharedConfigCredentialSource(t *testing.T) { var configFileForWindows = filepath.Join("testdata", "config_source_shared_for_windows") var configFile = filepath.Join("testdata", "config_source_shared") @@ -95,34 +147,38 @@ func TestSharedConfigCredentialSource(t *testing.T) { var credFile = filepath.Join("testdata", "credentials_source_shared") cases := map[string]struct { - name string - envProfile string - configProfile string - expectedError string - expectedAccessKey string - expectedSecretKey string - expectedChain []string - init func() - dependentOnOS bool + name string + envProfile string + configProfile string + expectedError string + expectedAccessKey string + expectedSecretKey string + expectedSessionToken string + expectedChain []string + init func() (func(), error) + dependentOnOS bool }{ "credential source and source profile": { envProfile: "invalid_source_and_credential_source", expectedError: "only one credential type may be specified per profile", - init: func() { + init: func() (func(), error) { os.Setenv("AWS_ACCESS_KEY", "access_key") os.Setenv("AWS_SECRET_KEY", "secret_key") + return func() {}, nil }, }, "env var credential source": { - configProfile: "env_var_credential_source", - expectedAccessKey: "AKID", - expectedSecretKey: "SECRET", + configProfile: "env_var_credential_source", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", expectedChain: []string{ "assume_role_w_creds_role_arn_env", }, - init: func() { + init: func() (func(), error) { os.Setenv("AWS_ACCESS_KEY", "access_key") os.Setenv("AWS_SECRET_KEY", "secret_key") + return func() {}, nil }, }, "ec2metadata credential source": { @@ -130,24 +186,28 @@ func TestSharedConfigCredentialSource(t *testing.T) { expectedChain: []string{ "assume_role_w_creds_role_arn_ec2", }, - expectedAccessKey: "AKID", - expectedSecretKey: "SECRET", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", }, "ecs container credential source": { - envProfile: "ecscontainer", - expectedAccessKey: "AKID", - expectedSecretKey: "SECRET", + envProfile: "ecscontainer", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", expectedChain: []string{ "assume_role_w_creds_role_arn_ecs", }, - init: func() { + init: func() (func(), error) { os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS") + return func() {}, nil }, }, "chained assume role with env creds": { - envProfile: "chained_assume_role", - expectedAccessKey: "AKID", - expectedSecretKey: "SECRET", + envProfile: "chained_assume_role", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", expectedChain: []string{ "assume_role_w_creds_role_arn_chain", "assume_role_w_creds_role_arn_ec2", @@ -160,43 +220,79 @@ func TestSharedConfigCredentialSource(t *testing.T) { expectedSecretKey: "cred_proc_secret", }, "credential process with ARN set": { - envProfile: "cred_proc_arn_set", - dependentOnOS: true, - expectedAccessKey: "AKID", - expectedSecretKey: "SECRET", + envProfile: "cred_proc_arn_set", + dependentOnOS: true, + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", expectedChain: []string{ "assume_role_w_creds_proc_role_arn", }, }, "chained assume role with credential process": { - envProfile: "chained_cred_proc", - dependentOnOS: true, - expectedAccessKey: "AKID", - expectedSecretKey: "SECRET", + envProfile: "chained_cred_proc", + dependentOnOS: true, + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", expectedChain: []string{ "assume_role_w_creds_proc_source_prof", }, }, "credential source overrides config source": { - envProfile: "credentials_overide", - expectedAccessKey: "AKID", - expectedSecretKey: "SECRET", + envProfile: "credentials_overide", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", expectedChain: []string{ "assume_role_w_creds_role_arn_ec2", }, - init: func() { + init: func() (func(), error) { os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS") + return func() {}, nil }, }, "only credential source": { - envProfile: "only_credentials_source", - expectedAccessKey: "AKID", - expectedSecretKey: "SECRET", + envProfile: "only_credentials_source", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", expectedChain: []string{ "assume_role_w_creds_role_arn_ecs", }, - init: func() { + init: func() (func(), error) { os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS") + return func() {}, nil + }, + }, + "sso credentials": { + envProfile: "sso_creds", + expectedAccessKey: "SSO_AKID", + expectedSecretKey: "SSO_SECRET_KEY", + expectedSessionToken: "SSO_SESSION_TOKEN", + init: func() (func(), error) { + return ssoTestSetup() + }, + }, + "chained assume role with sso credentials": { + envProfile: "source_sso_creds", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", + expectedChain: []string{ + "source_sso_creds_arn", + }, + init: func() (func(), error) { + return ssoTestSetup() + }, + }, + "chained assume role with sso and static credentials": { + envProfile: "assume_sso_and_static", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedSessionToken: "SESSION_TOKEN", + expectedChain: []string{ + "assume_sso_and_static_arn", }, }, } @@ -222,8 +318,14 @@ func TestSharedConfigCredentialSource(t *testing.T) { endpointResolver, cleanupFn := setupCredentialsEndpoints(t) defer cleanupFn() + var cleanup func() if c.init != nil { - c.init() + var err error + cleanup, err = c.init() + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + defer cleanup() } var credChain []string @@ -278,7 +380,11 @@ func TestSharedConfigCredentialSource(t *testing.T) { } if e, a := c.expectedSecretKey, creds.SecretAccessKey; e != a { - t.Errorf("expected %v, but received %v", e, a) + t.Errorf("expect %v, but received %v", e, a) + } + + if e, a := c.expectedSessionToken, creds.SessionToken; e != a { + t.Errorf("expect %v, got %v", e, a) } }) } diff --git a/config/shared_config.go b/config/shared_config.go index 1a715959e47..c3a0c97a436 100644 --- a/config/shared_config.go +++ b/config/shared_config.go @@ -759,15 +759,15 @@ func (c *SharedConfig) setFromIniSections(profiles map[string]struct{}, profile c.clearAssumeRoleOptions() } else { // First time a profile has been seen, It must either be a assume role - // or credentials. Assert if the credential type requires a role ARN, - // the ARN is also set. + // credentials, or SSO. Assert if the credential type requires a role ARN, + // the ARN is also set, or validate that the SSO configuration is complete. if err := c.validateCredentialsConfig(profile); err != nil { return err } } // if not top level profile and has credentials, return with credentials. - if len(profiles) != 0 && (c.Credentials.HasKeys() || c.hasSSOConfiguration()) { + if len(profiles) != 0 && c.Credentials.HasKeys() { return nil } @@ -951,7 +951,7 @@ func (c *SharedConfig) validateSSOConfiguration(profile string) error { if len(missing) > 0 { return fmt.Errorf("profile %q is configured to use SSO but is missing required configuration: %s", - profile, strings.Join(missing, ",")) + profile, strings.Join(missing, ", ")) } return nil diff --git a/config/shared_config_test.go b/config/shared_config_test.go index 6087c571eaa..fcc2083b6df 100644 --- a/config/shared_config_test.go +++ b/config/shared_config_test.go @@ -14,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go-v2/internal/ini" "github.com/aws/smithy-go/logging" "github.com/aws/smithy-go/ptr" + "github.com/google/go-cmp/cmp" ) var _ regionProvider = (*SharedConfig)(nil) @@ -245,6 +246,33 @@ func TestNewSharedConfig(t *testing.T) { }, }, }, + "AWS SSO Invalid Profile": { + Filenames: []string{testConfigFilename}, + Profile: "invalid_sso_creds", + Err: fmt.Errorf("profile \"invalid_sso_creds\" is configured to use SSO but is missing required configuration: sso_region, sso_role_name, sso_start_url"), + }, + "AWS SSO Profile and Static Credentials": { + Filenames: []string{testConfigFilename}, + Profile: "sso_and_static", + Expected: SharedConfig{ + Profile: "sso_and_static", + Credentials: aws.Credentials{ + AccessKeyID: "sso_and_static_akid", + SecretAccessKey: "sso_and_static_secret", + SessionToken: "sso_and_static_token", + Source: "SharedConfigCredentials: testdata/shared_config", + }, + SSOAccountID: "012345678901", + SSORegion: "us-west-2", + SSORoleName: "TestRole", + SSOStartURL: "https://127.0.0.1/start", + }, + }, + "Assume Role with AWS SSO Configuration and Source Profile": { + Filenames: []string{testConfigFilename}, + Profile: "source_sso_and_assume", + Err: fmt.Errorf("only one credential type may be specified per profile"), + }, } for name, c := range cases { @@ -265,8 +293,8 @@ func TestNewSharedConfig(t *testing.T) { if c.Err != nil { t.Errorf("expect error: %v, got none", c.Err) } - if e, a := c.Expected, cfg; !reflect.DeepEqual(e, a) { - t.Errorf(" expect %v, got %v", e, a) + if diff := cmp.Diff(c.Expected, cfg); len(diff) > 0 { + t.Error(diff) } }) } diff --git a/config/shared_test.go b/config/shared_test.go index 2607e3440fe..ebe91923385 100644 --- a/config/shared_test.go +++ b/config/shared_test.go @@ -47,6 +47,20 @@ const assumeRoleRespMsg = ` ` +const getRoleCredentialsResponse = `{ + "roleCredentials": { + "accessKeyId": "SSO_AKID", + "secretAccessKey": "SSO_SECRET_KEY", + "sessionToken": "SSO_SESSION_TOKEN", + "expiration": %d + } +}` + +const ssoTokenCacheFile = `{ + "accessToken": "ssoAccessToken", + "expiresAt": "%s" +}` + type mockHTTPClient func(*http.Request) (*http.Response, error) func (m mockHTTPClient) Do(r *http.Request) (*http.Response, error) { diff --git a/config/testdata/config_source_shared b/config/testdata/config_source_shared index e94eb688a7a..980a7c10470 100644 --- a/config/testdata/config_source_shared +++ b/config/testdata/config_source_shared @@ -33,3 +33,27 @@ source_profile = cred_proc_no_arn_set [profile credentials_overide] role_arn = assume_role_w_creds_role_arn_ec2 credential_source = Ec2InstanceMetadata + +[profile sso_creds] +sso_account_id = 012345678901 +sso_region = us-west-2 +sso_role_name = TestRole +sso_start_url = https://127.0.0.1/start + +[profile source_sso_creds] +role_arn = source_sso_creds_arn +source_profile = sso_creds + +[profile assume_sso_and_static] +role_arn = assume_sso_and_static_arn +source_profile = sso_and_static + +[profile sso_and_static] +aws_access_key_id = sso_and_static_akid +aws_secret_access_key = sso_and_static_secret +aws_session_token = sso_and_static_token +sso_account_id = 012345678901 +sso_region = us-west-2 +sso_role_name = TestRole +sso_start_url = https://127.0.0.1/start + diff --git a/config/testdata/shared_config b/config/testdata/shared_config index 578fd56e469..ff12f084730 100644 --- a/config/testdata/shared_config +++ b/config/testdata/shared_config @@ -117,3 +117,24 @@ source_profile = sso_creds [profile invalid_sso_creds] sso_account_id = 012345678901 + +[profile sso_and_static] +aws_access_key_id = sso_and_static_akid +aws_secret_access_key = sso_and_static_secret +aws_session_token = sso_and_static_token +sso_account_id = 012345678901 +sso_region = us-west-2 +sso_role_name = TestRole +sso_start_url = https://127.0.0.1/start + +[profile sso_and_assume] +sso_account_id = 012345678901 +sso_region = us-west-2 +sso_role_name = TestRole +sso_start_url = https://127.0.0.1/start +role_arn = sso_with_assume_role_arn +source_profile = multiple_assume_role_with_credential_source + +[profile source_sso_and_assume] +role_arn = source_sso_and_assume_arn +source_profile = sso_and_assume diff --git a/credentials/ssocreds/provider.go b/credentials/ssocreds/provider.go index 344f299c7c3..7a2f6aaea76 100644 --- a/credentials/ssocreds/provider.go +++ b/credentials/ssocreds/provider.go @@ -19,7 +19,13 @@ import ( // ProviderName is the name of the provider used to specify the source of credentials. const ProviderName = "SSOProvider" -var defaultCacheLocation = filepath.Join(getHomeDirectory(), ".aws", "sso", "cache") +var defaultCacheLocation func() string + +func init() { + defaultCacheLocation = func() string { + return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache") + } +} // GetRoleCredentialsAPIClient is a API client that implements the GetRoleCredentials operation. type GetRoleCredentialsAPIClient interface { @@ -155,7 +161,7 @@ func loadTokenFile(startURL string) (t token, err error) { return token{}, &InvalidTokenError{Err: err} } - fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation, key)) + fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key)) if err != nil { return token{}, &InvalidTokenError{Err: err} } diff --git a/credentials/ssocreds/provider_test.go b/credentials/ssocreds/provider_test.go index 18122b5ded9..cab42d98333 100644 --- a/credentials/ssocreds/provider_test.go +++ b/credentials/ssocreds/provider_test.go @@ -55,19 +55,11 @@ func (m mockClient) GetRoleCredentials(ctx context.Context, params *sso.GetRoleC func swapCacheLocation(dir string) func() { original := defaultCacheLocation - defaultCacheLocation = dir - return func() { - defaultCacheLocation = original - } -} - -func swapNowTime(referenceTime time.Time) func() { - original := sdk.NowTime - sdk.NowTime = func() time.Time { - return referenceTime + defaultCacheLocation = func() string { + return dir } return func() { - sdk.NowTime = original + defaultCacheLocation = original } } @@ -75,7 +67,7 @@ func TestProvider(t *testing.T) { restoreCache := swapCacheLocation("testdata") defer restoreCache() - restoreTime := swapNowTime(time.Date(2021, 01, 19, 19, 50, 0, 0, time.UTC)) + restoreTime := sdk.TestingUseReferenceTime(time.Date(2021, 01, 19, 19, 50, 0, 0, time.UTC)) defer restoreTime() cases := map[string]struct { diff --git a/internal/sdk/time.go b/internal/sdk/time.go index 7b1e5d92752..8e8dabad548 100644 --- a/internal/sdk/time.go +++ b/internal/sdk/time.go @@ -61,3 +61,14 @@ func TestingUseNopSleep() func() { Sleep = time.Sleep } } + +// TestingUseReferenceTime is a utility for swapping the time function across the SDK to return a specific reference time +// for testing purposes. +func TestingUseReferenceTime(referenceTime time.Time) func() { + NowTime = func() time.Time { + return referenceTime + } + return func() { + NowTime = time.Now + } +}