diff --git a/awsutil/clients.go b/awsutil/clients.go new file mode 100644 index 0000000..bb27f4c --- /dev/null +++ b/awsutil/clients.go @@ -0,0 +1,90 @@ +package awsutil + +import ( + "errors" + "fmt" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/iam/iamiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" +) + +// IAMAPIFunc is a factory function for returning an IAM interface, +// useful for supplying mock interfaces for testing IAM. The session +// is passed into the function in the same way as done with the +// standard iam.New() constructor. +type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error) + +// STSAPIFunc is a factory function for returning a STS interface, +// useful for supplying mock interfaces for testing STS. The session +// is passed into the function in the same way as done with the +// standard sts.New() constructor. +type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error) + +// IAMClient returns an IAM client. +// +// Supported options: WithSession, WithIAMAPIFunc. +// +// If WithIAMAPIFunc is supplied, the included function is used as +// the IAM client constructor instead. This can be used for Mocking +// the IAM API. +func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) { + opts, err := getOpts(opt...) + if err != nil { + return nil, fmt.Errorf("error reading options: %w", err) + } + + sess := opts.withAwsSession + if sess == nil { + sess, err = c.GetSession(opt...) + if err != nil { + return nil, fmt.Errorf("error calling GetSession: %w", err) + } + } + + if opts.withIAMAPIFunc != nil { + return opts.withIAMAPIFunc(sess) + } + + client := iam.New(sess) + if client == nil { + return nil, errors.New("could not obtain iam client from session") + } + + return client, nil +} + +// STSClient returns a STS client. +// +// Supported options: WithSession, WithSTSAPIFunc. +// +// If WithSTSAPIFunc is supplied, the included function is used as +// the STS client constructor instead. This can be used for Mocking +// the STS API. +func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) { + opts, err := getOpts(opt...) + if err != nil { + return nil, fmt.Errorf("error reading options: %w", err) + } + + sess := opts.withAwsSession + if sess == nil { + sess, err = c.GetSession(opt...) + if err != nil { + return nil, fmt.Errorf("error calling GetSession: %w", err) + } + } + + if opts.withSTSAPIFunc != nil { + return opts.withSTSAPIFunc(sess) + } + + client := sts.New(sess) + if client == nil { + return nil, errors.New("could not obtain sts client from session") + } + + return client, nil +} diff --git a/awsutil/clients_test.go b/awsutil/clients_test.go new file mode 100644 index 0000000..a6db1b2 --- /dev/null +++ b/awsutil/clients_test.go @@ -0,0 +1,141 @@ +package awsutil + +import ( + "errors" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/iam/iamiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/stretchr/testify/require" +) + +const testOptionErr = "test option error" +const testBadClientType = "badclienttype" + +func testWithOptionError(_ *options) error { + return errors.New(testOptionErr) +} + +func testWithBadClientType(o *options) error { + o.withClientType = testBadClientType + return nil +} + +func TestCredentialsConfigIAMClient(t *testing.T) { + cases := []struct { + name string + credentialsConfig *CredentialsConfig + opts []Option + require func(t *testing.T, actual iamiface.IAMAPI) + requireErr string + }{ + { + name: "options error", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{testWithOptionError}, + requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), + }, + { + name: "session error", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{testWithBadClientType}, + requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), + }, + { + name: "with mock IAM session", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{WithIAMAPIFunc(NewMockIAM())}, + require: func(t *testing.T, actual iamiface.IAMAPI) { + t.Helper() + require := require.New(t) + require.Equal(&MockIAM{}, actual) + }, + }, + { + name: "no mock client", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{}, + require: func(t *testing.T, actual iamiface.IAMAPI) { + t.Helper() + require := require.New(t) + require.IsType(&iam.IAM{}, actual) + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + actual, err := tc.credentialsConfig.IAMClient(tc.opts...) + if tc.requireErr != "" { + require.EqualError(err, tc.requireErr) + return + } + + require.NoError(err) + tc.require(t, actual) + }) + } +} + +func TestCredentialsConfigSTSClient(t *testing.T) { + cases := []struct { + name string + credentialsConfig *CredentialsConfig + opts []Option + require func(t *testing.T, actual stsiface.STSAPI) + requireErr string + }{ + { + name: "options error", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{testWithOptionError}, + requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), + }, + { + name: "session error", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{testWithBadClientType}, + requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), + }, + { + name: "with mock STS session", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{WithSTSAPIFunc(NewMockSTS())}, + require: func(t *testing.T, actual stsiface.STSAPI) { + t.Helper() + require := require.New(t) + require.Equal(&MockSTS{}, actual) + }, + }, + { + name: "no mock client", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{}, + require: func(t *testing.T, actual stsiface.STSAPI) { + t.Helper() + require := require.New(t) + require.IsType(&sts.STS{}, actual) + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + actual, err := tc.credentialsConfig.STSClient(tc.opts...) + if tc.requireErr != "" { + require.EqualError(err, tc.requireErr) + return + } + + require.NoError(err) + tc.require(t, actual) + }) + } +} diff --git a/awsutil/mocks.go b/awsutil/mocks.go index 43e4bb2..9ef1ad9 100644 --- a/awsutil/mocks.go +++ b/awsutil/mocks.go @@ -1,26 +1,157 @@ package awsutil import ( + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" ) +// MockIAM provides a way to mock the AWS IAM API. type MockIAM struct { iamiface.IAMAPI CreateAccessKeyOutput *iam.CreateAccessKeyOutput - DeleteAccessKeyOutput *iam.DeleteAccessKeyOutput + CreateAccessKeyError error + DeleteAccessKeyError error GetUserOutput *iam.GetUserOutput + GetUserError error +} + +// MockIAMOption is a function for setting the various fields on a MockIAM +// object. +type MockIAMOption func(m *MockIAM) error + +// WithCreateAccessKeyOutput sets the output for the CreateAccessKey method. +func WithCreateAccessKeyOutput(o *iam.CreateAccessKeyOutput) MockIAMOption { + return func(m *MockIAM) error { + m.CreateAccessKeyOutput = o + return nil + } +} + +// WithCreateAccessKeyError sets the error output for the CreateAccessKey +// method. +func WithCreateAccessKeyError(e error) MockIAMOption { + return func(m *MockIAM) error { + m.CreateAccessKeyError = e + return nil + } +} + +// WithDeleteAccessKeyError sets the error output for the DeleteAccessKey +// method. +func WithDeleteAccessKeyError(e error) MockIAMOption { + return func(m *MockIAM) error { + m.DeleteAccessKeyError = e + return nil + } +} + +// WithGetUserOutput sets the output for the GetUser method. +func WithGetUserOutput(o *iam.GetUserOutput) MockIAMOption { + return func(m *MockIAM) error { + m.GetUserOutput = o + return nil + } +} + +// WithGetUserError sets the error output for the GetUser method. +func WithGetUserError(e error) MockIAMOption { + return func(m *MockIAM) error { + m.GetUserError = e + return nil + } +} + +// NewMockIAM provides a factory function to use with the WithIAMAPIFunc +// option. +func NewMockIAM(opts ...MockIAMOption) IAMAPIFunc { + return func(_ *session.Session) (iamiface.IAMAPI, error) { + m := new(MockIAM) + for _, opt := range opts { + if err := opt(m); err != nil { + return nil, err + } + } + + return m, nil + } } func (m *MockIAM) CreateAccessKey(*iam.CreateAccessKeyInput) (*iam.CreateAccessKeyOutput, error) { + if m.CreateAccessKeyError != nil { + return nil, m.CreateAccessKeyError + } + return m.CreateAccessKeyOutput, nil } func (m *MockIAM) DeleteAccessKey(*iam.DeleteAccessKeyInput) (*iam.DeleteAccessKeyOutput, error) { - return m.DeleteAccessKeyOutput, nil + return &iam.DeleteAccessKeyOutput{}, m.DeleteAccessKeyError } func (m *MockIAM) GetUser(*iam.GetUserInput) (*iam.GetUserOutput, error) { + if m.GetUserError != nil { + return nil, m.GetUserError + } + return m.GetUserOutput, nil } + +// MockSTS provides a way to mock the AWS STS API. +type MockSTS struct { + stsiface.STSAPI + + GetCallerIdentityOutput *sts.GetCallerIdentityOutput + GetCallerIdentityError error +} + +// MockSTSOption is a function for setting the various fields on a MockSTS +// object. +type MockSTSOption func(m *MockSTS) error + +// WithGetCallerIdentityOutput sets the output for the GetCallerIdentity +// method. +func WithGetCallerIdentityOutput(o *sts.GetCallerIdentityOutput) MockSTSOption { + return func(m *MockSTS) error { + m.GetCallerIdentityOutput = o + return nil + } +} + +// WithGetCallerIdentityError sets the error output for the GetCallerIdentity +// method. +func WithGetCallerIdentityError(e error) MockSTSOption { + return func(m *MockSTS) error { + m.GetCallerIdentityError = e + return nil + } +} + +// NewMockSTS provides a factory function to use with the WithSTSAPIFunc +// option. +// +// If withGetCallerIdentityError is supplied, calls to GetCallerIdentity will +// return the supplied error. Otherwise, a basic mock API output is returned. +func NewMockSTS(opts ...MockSTSOption) STSAPIFunc { + return func(_ *session.Session) (stsiface.STSAPI, error) { + m := new(MockSTS) + for _, opt := range opts { + if err := opt(m); err != nil { + return nil, err + } + } + + return m, nil + } +} + +func (m *MockSTS) GetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { + if m.GetCallerIdentityError != nil { + return nil, m.GetCallerIdentityError + } + + return m.GetCallerIdentityOutput, nil +} diff --git a/awsutil/mocks_test.go b/awsutil/mocks_test.go new file mode 100644 index 0000000..85b3e60 --- /dev/null +++ b/awsutil/mocks_test.go @@ -0,0 +1,141 @@ +package awsutil + +import ( + "errors" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMockIAM(t *testing.T) { + cases := []struct { + name string + opts []MockIAMOption + expectedCreateAccessKeyOutput *iam.CreateAccessKeyOutput + expectedCreateAccessKeyError error + expectedDeleteAccessKeyError error + expectedGetUserOutput *iam.GetUserOutput + expectedGetUserError error + }{ + { + name: "CreateAccessKeyOutput", + opts: []MockIAMOption{WithCreateAccessKeyOutput( + &iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + }, + }, + )}, + expectedCreateAccessKeyOutput: &iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + }, + }, + }, + { + name: "CreateAccessKeyError", + opts: []MockIAMOption{WithCreateAccessKeyError(errors.New("testerr"))}, + expectedCreateAccessKeyError: errors.New("testerr"), + }, + { + name: "DeleteAccessKeyError", + opts: []MockIAMOption{WithDeleteAccessKeyError(errors.New("testerr"))}, + expectedDeleteAccessKeyError: errors.New("testerr"), + }, + { + name: "GetUserOutput", + opts: []MockIAMOption{WithGetUserOutput( + &iam.GetUserOutput{ + User: &iam.User{ + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + UserName: aws.String("JohnDoe"), + }, + }, + )}, + expectedGetUserOutput: &iam.GetUserOutput{ + User: &iam.User{ + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + UserName: aws.String("JohnDoe"), + }, + }, + }, + { + name: "GetUserError", + opts: []MockIAMOption{WithGetUserError(errors.New("testerr"))}, + expectedGetUserError: errors.New("testerr"), + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + f := NewMockIAM(tc.opts...) + m, err := f(nil) + require.NoError(err) // Nothing returns an error right now + actualCreateAccessKeyOutput, actualCreateAccessKeyError := m.CreateAccessKey(nil) + _, actualDeleteAccessKeyError := m.DeleteAccessKey(nil) + actualGetUserOutput, actualGetUserError := m.GetUser(nil) + assert.Equal(tc.expectedCreateAccessKeyOutput, actualCreateAccessKeyOutput) + assert.Equal(tc.expectedCreateAccessKeyError, actualCreateAccessKeyError) + assert.Equal(tc.expectedDeleteAccessKeyError, actualDeleteAccessKeyError) + assert.Equal(tc.expectedGetUserOutput, actualGetUserOutput) + assert.Equal(tc.expectedGetUserError, actualGetUserError) + }) + } +} + +func TestMockSTS(t *testing.T) { + cases := []struct { + name string + opts []MockSTSOption + expectedGetCallerIdentityOutput *sts.GetCallerIdentityOutput + expectedGetCallerIdentityError error + }{ + { + name: "GetCallerIdentityOutput", + opts: []MockSTSOption{WithGetCallerIdentityOutput( + &sts.GetCallerIdentityOutput{ + Account: aws.String("1234567890"), + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + }, + )}, + expectedGetCallerIdentityOutput: &sts.GetCallerIdentityOutput{ + Account: aws.String("1234567890"), + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + }, + }, + { + name: "GetCallerIdentityError", + opts: []MockSTSOption{WithGetCallerIdentityError(errors.New("testerr"))}, + expectedGetCallerIdentityError: errors.New("testerr"), + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + f := NewMockSTS(tc.opts...) + m, err := f(nil) + require.NoError(err) // Nothing returns an error right now + actualGetCallerIdentityOutput, actualGetCallerIdentityError := m.GetCallerIdentity(nil) + assert.Equal(tc.expectedGetCallerIdentityOutput, actualGetCallerIdentityOutput) + assert.Equal(tc.expectedGetCallerIdentityError, actualGetCallerIdentityError) + }) + } +} diff --git a/awsutil/options.go b/awsutil/options.go index d3da48f..0d66a16 100644 --- a/awsutil/options.go +++ b/awsutil/options.go @@ -42,6 +42,8 @@ type options struct { withRegion string withHttpClient *http.Client withValidityCheckTimeout time.Duration + withIAMAPIFunc IAMAPIFunc + withSTSAPIFunc STSAPIFunc } func getDefaultOptions() options { @@ -170,3 +172,21 @@ func WithValidityCheckTimeout(with time.Duration) Option { return nil } } + +// WithIAMAPIFunc allows passing in an IAM interface constructor for mocking +// the AWS IAM API. +func WithIAMAPIFunc(with IAMAPIFunc) Option { + return func(o *options) error { + o.withIAMAPIFunc = with + return nil + } +} + +// WithSTSAPIFunc allows passing in a STS interface constructor for mocking the +// AWS STS API. +func WithSTSAPIFunc(with STSAPIFunc) Option { + return func(o *options) error { + o.withSTSAPIFunc = with + return nil + } +} diff --git a/awsutil/options_test.go b/awsutil/options_test.go index 65fe352..9e3169e 100644 --- a/awsutil/options_test.go +++ b/awsutil/options_test.go @@ -118,4 +118,14 @@ func Test_GetOpts(t *testing.T) { require.NoError(t, err) assert.Equal(t, opts.withValidityCheckTimeout, time.Second) }) + t.Run("withIAMIface", func(t *testing.T) { + opts, err := getOpts(WithIAMAPIFunc(NewMockIAM())) + require.NoError(t, err) + assert.NotNil(t, opts.withIAMAPIFunc) + }) + t.Run("withSTSIface", func(t *testing.T) { + opts, err := getOpts(WithSTSAPIFunc(NewMockSTS())) + require.NoError(t, err) + assert.NotNil(t, opts.withSTSAPIFunc) + }) } diff --git a/awsutil/rotate.go b/awsutil/rotate.go index 0fbc7d2..a92fd62 100644 --- a/awsutil/rotate.go +++ b/awsutil/rotate.go @@ -23,9 +23,12 @@ import ( // if the old one could not be deleted. // // Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUsername, WithValidityCheckTimeout. Note that WithValidityCheckTimeout -// here, when non-zero, controls the WithValidityCheckTimeout option on access key -// creation. See CreateAccessKey for more details. +// WithAwsSession, WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, +// WithSTSAPIFunc +// +// Note that WithValidityCheckTimeout here, when non-zero, controls the +// WithValidityCheckTimeout option on access key creation. See CreateAccessKey +// for more details. func (c *CredentialsConfig) RotateKeys(opt ...Option) error { if c.AccessKey == "" || c.SecretKey == "" { return errors.New("cannot rotate credentials when either access_key or secret_key is empty") @@ -64,7 +67,8 @@ func (c *CredentialsConfig) RotateKeys(opt ...Option) error { // CreateAccessKey creates a new access/secret key pair. // // Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUsername, WithValidityCheckTimeout +// WithAwsSession, WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, +// WithSTSAPIFunc // // When WithValidityCheckTimeout is non-zero, it specifies a timeout to wait on // the created credentials to be valid and ready for use. @@ -74,17 +78,9 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey return nil, fmt.Errorf("error reading options in CreateAccessKey: %w", err) } - sess := opts.withAwsSession - if sess == nil { - sess, err = c.GetSession(opt...) - if err != nil { - return nil, fmt.Errorf("error calling GetSession: %w", err) - } - } - - client := iam.New(sess) - if client == nil { - return nil, errors.New("could not obtain iam client from session") + client, err := c.IAMClient(opt...) + if err != nil { + return nil, fmt.Errorf("error loading IAM client: %w", err) } var getUserInput iam.GetUserInput @@ -112,9 +108,12 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey if err != nil { return nil, fmt.Errorf("error calling aws.CreateAccessKey: %w", err) } - if createAccessKeyRes.AccessKey == nil { + if createAccessKeyRes == nil { return nil, fmt.Errorf("nil response from aws.CreateAccessKey") } + if createAccessKeyRes.AccessKey == nil { + return nil, fmt.Errorf("nil access key in response from aws.CreateAccessKey") + } if createAccessKeyRes.AccessKey.AccessKeyId == nil || createAccessKeyRes.AccessKey.SecretAccessKey == nil { return nil, fmt.Errorf("nil AccessKeyId or SecretAccessKey returned from aws.CreateAccessKey") } @@ -128,7 +127,10 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey SecretKey: *createAccessKeyRes.AccessKey.SecretAccessKey, } - if _, err := newC.GetCallerIdentity(WithValidityCheckTimeout(opts.withValidityCheckTimeout)); err != nil { + if _, err := newC.GetCallerIdentity( + WithValidityCheckTimeout(opts.withValidityCheckTimeout), + WithSTSAPIFunc(opts.withSTSAPIFunc), + ); err != nil { return nil, fmt.Errorf("error verifying new credentials: %w", err) } } @@ -139,24 +141,16 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey // DeleteAccessKey deletes an access key. // // Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUserName +// WithAwsSession, WithUserName, WithIAMAPIFunc func (c *CredentialsConfig) DeleteAccessKey(accessKeyId string, opt ...Option) error { opts, err := getOpts(opt...) if err != nil { return fmt.Errorf("error reading options in RotateKeys: %w", err) } - sess := opts.withAwsSession - if sess == nil { - sess, err = c.GetSession(opt...) - if err != nil { - return fmt.Errorf("error calling GetSession: %w", err) - } - } - - client := iam.New(sess) - if client == nil { - return errors.New("could not obtain iam client from session") + client, err := c.IAMClient(opt...) + if err != nil { + return fmt.Errorf("error loading IAM client: %w", err) } deleteAccessKeyInput := iam.DeleteAccessKeyInput{ @@ -230,17 +224,9 @@ func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIden return nil, fmt.Errorf("error reading options in GetCallerIdentity: %w", err) } - sess := opts.withAwsSession - if sess == nil { - sess, err = c.GetSession(opt...) - if err != nil { - return nil, fmt.Errorf("error calling GetSession: %w", err) - } - } - - client := sts.New(sess) - if client == nil { - return nil, errors.New("could not obtain STS client from session") + client, err := c.STSClient(opt...) + if err != nil { + return nil, fmt.Errorf("error loading STS client: %w", err) } delay := time.Second diff --git a/awsutil/rotate_test.go b/awsutil/rotate_test.go index 858b34b..557384d 100644 --- a/awsutil/rotate_test.go +++ b/awsutil/rotate_test.go @@ -7,7 +7,10 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -131,3 +134,203 @@ func TestCallerIdentityErrorWithValidityCheckTimeout(t *testing.T) { require.NotNil(err) require.Implements((*awserr.Error)(nil), err) } + +func TestCallerIdentityWithSTSMockError(t *testing.T) { + require := require.New(t) + + expectedErr := errors.New("this is the expected error") + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.GetCallerIdentity(WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityError(expectedErr)))) + require.EqualError(err, expectedErr.Error()) +} + +func TestCallerIdentityWithSTSMockNoErorr(t *testing.T) { + require := require.New(t) + + expectedOut := &sts.GetCallerIdentityOutput{ + Account: aws.String("1234567890"), + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + } + + c, err := NewCredentialsConfig() + require.NoError(err) + out, err := c.GetCallerIdentity(WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityOutput(expectedOut)))) + require.NoError(err) + require.Equal(expectedOut, out) +} + +func TestDeleteAccessKeyWithIAMMock(t *testing.T) { + require := require.New(t) + + mockErr := errors.New("this is the expected error") + expectedErr := "error deleting old access key: this is the expected error" + c, err := NewCredentialsConfig() + require.NoError(err) + err = c.DeleteAccessKey("foobar", WithIAMAPIFunc(NewMockIAM(WithDeleteAccessKeyError(mockErr)))) + require.EqualError(err, expectedErr) +} + +func TestCreateAccessKeyWithIAMMockGetUserError(t *testing.T) { + require := require.New(t) + + mockErr := errors.New("this is the expected error") + expectedErr := "error calling aws.GetUser: this is the expected error" + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.CreateAccessKey(WithIAMAPIFunc(NewMockIAM(WithGetUserError(mockErr)))) + require.EqualError(err, expectedErr) +} + +func TestCreateAccessKeyWithIAMMockCreateAccessKeyError(t *testing.T) { + require := require.New(t) + + mockErr := errors.New("this is the expected error") + expectedErr := "error calling aws.CreateAccessKey: this is the expected error" + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.CreateAccessKey(WithIAMAPIFunc(NewMockIAM( + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + WithCreateAccessKeyError(mockErr), + ))) + require.EqualError(err, expectedErr) +} + +func TestCreateAccessKeyWithIAMAndSTSMockGetCallerIdentityError(t *testing.T) { + require := require.New(t) + + mockErr := errors.New("this is the expected error") + expectedErr := "error verifying new credentials: timeout after 1ns waiting for success: this is the expected error" + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.CreateAccessKey( + WithValidityCheckTimeout(time.Nanosecond), + WithIAMAPIFunc(NewMockIAM( + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + }, + }), + )), + WithSTSAPIFunc(NewMockSTS( + WithGetCallerIdentityError(mockErr), + )), + ) + require.EqualError(err, expectedErr) +} + +func TestCreateAccessKeyNilResponse(t *testing.T) { + require := require.New(t) + + expectedErr := "nil response from aws.CreateAccessKey" + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.CreateAccessKey( + WithValidityCheckTimeout(time.Nanosecond), + WithIAMAPIFunc(NewMockIAM( + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + )), + ) + require.EqualError(err, expectedErr) +} + +func TestRotateKeysWithMocks(t *testing.T) { + mockErr := errors.New("this is the expected error") + cases := []struct { + name string + mockIAMOpts []MockIAMOption + mockSTSOpts []MockSTSOption + require func(t *testing.T, actual *CredentialsConfig) + requireErr string + }{ + { + name: "CreateAccessKey IAM error", + mockIAMOpts: []MockIAMOption{WithGetUserError(mockErr)}, + requireErr: "error calling CreateAccessKey: error calling aws.GetUser: this is the expected error", + }, + { + name: "CreateAccessKey STS error", + mockIAMOpts: []MockIAMOption{ + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + }, + }), + }, + mockSTSOpts: []MockSTSOption{WithGetCallerIdentityError(mockErr)}, + requireErr: "error calling CreateAccessKey: error verifying new credentials: timeout after 1ns waiting for success: this is the expected error", + }, + { + name: "DeleteAccessKey IAM error", + mockIAMOpts: []MockIAMOption{ + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + UserName: aws.String("foouser"), + }, + }), + // DeleteAccessKeyOutput w/o error is a no-op in the mock and + // will return without additional stubbing + }, + mockSTSOpts: []MockSTSOption{WithGetCallerIdentityOutput(&sts.GetCallerIdentityOutput{})}, + require: func(t *testing.T, actual *CredentialsConfig) { + t.Helper() + require := require.New(t) + + require.Equal("foobar", actual.AccessKey) + require.Equal("bazqux", actual.SecretKey) + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + c, err := NewCredentialsConfig( + WithAccessKey("foo"), + WithSecretKey("bar"), + ) + require.NoError(err) + err = c.RotateKeys( + WithIAMAPIFunc(NewMockIAM(tc.mockIAMOpts...)), + WithSTSAPIFunc(NewMockSTS(tc.mockSTSOpts...)), + WithValidityCheckTimeout(time.Nanosecond), + ) + if tc.requireErr != "" { + require.EqualError(err, tc.requireErr) + return + } + + require.NoError(err) + tc.require(t, c) + }) + } +}