From 79c4376dbe7cc862588eceafdf09c75344c1975c Mon Sep 17 00:00:00 2001 From: Chris Marchesi Date: Thu, 30 Sep 2021 17:02:12 -0700 Subject: [PATCH] awsutil: add ability to mock IAM and STS APIs This adds the ability to supply mock IAM and STS interfaces to use with the RotateKeys, CreateAccessKey, DeleteAccessKey, and GetCallerIdentity methods. Additionally, work has been done to incorporate the existing IAM mock object into the package, in addition to adding a similar STS mock that allows for the mocking of the GetCallerIdentity method. Factory functions are also included to allow these to be incorporated seamlessly into the call path, in addition to allowing for introspection into any session data being received by the constructor functions. --- awsutil/clients.go | 90 ++++++++++++++++++ awsutil/clients_test.go | 141 ++++++++++++++++++++++++++++ awsutil/mocks.go | 135 +++++++++++++++++++++++++- awsutil/mocks_test.go | 141 ++++++++++++++++++++++++++++ awsutil/options.go | 20 ++++ awsutil/options_test.go | 10 ++ awsutil/rotate.go | 66 +++++-------- awsutil/rotate_test.go | 203 ++++++++++++++++++++++++++++++++++++++++ 8 files changed, 764 insertions(+), 42 deletions(-) create mode 100644 awsutil/clients.go create mode 100644 awsutil/clients_test.go create mode 100644 awsutil/mocks_test.go diff --git a/awsutil/clients.go b/awsutil/clients.go new file mode 100644 index 0000000..bb27f4c --- /dev/null +++ b/awsutil/clients.go @@ -0,0 +1,90 @@ +package awsutil + +import ( + "errors" + "fmt" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/iam/iamiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" +) + +// IAMAPIFunc is a factory function for returning an IAM interface, +// useful for supplying mock interfaces for testing IAM. The session +// is passed into the function in the same way as done with the +// standard iam.New() constructor. +type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error) + +// STSAPIFunc is a factory function for returning a STS interface, +// useful for supplying mock interfaces for testing STS. The session +// is passed into the function in the same way as done with the +// standard sts.New() constructor. +type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error) + +// IAMClient returns an IAM client. +// +// Supported options: WithSession, WithIAMAPIFunc. +// +// If WithIAMAPIFunc is supplied, the included function is used as +// the IAM client constructor instead. This can be used for Mocking +// the IAM API. +func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) { + opts, err := getOpts(opt...) + if err != nil { + return nil, fmt.Errorf("error reading options: %w", err) + } + + sess := opts.withAwsSession + if sess == nil { + sess, err = c.GetSession(opt...) + if err != nil { + return nil, fmt.Errorf("error calling GetSession: %w", err) + } + } + + if opts.withIAMAPIFunc != nil { + return opts.withIAMAPIFunc(sess) + } + + client := iam.New(sess) + if client == nil { + return nil, errors.New("could not obtain iam client from session") + } + + return client, nil +} + +// STSClient returns a STS client. +// +// Supported options: WithSession, WithSTSAPIFunc. +// +// If WithSTSAPIFunc is supplied, the included function is used as +// the STS client constructor instead. This can be used for Mocking +// the STS API. +func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) { + opts, err := getOpts(opt...) + if err != nil { + return nil, fmt.Errorf("error reading options: %w", err) + } + + sess := opts.withAwsSession + if sess == nil { + sess, err = c.GetSession(opt...) + if err != nil { + return nil, fmt.Errorf("error calling GetSession: %w", err) + } + } + + if opts.withSTSAPIFunc != nil { + return opts.withSTSAPIFunc(sess) + } + + client := sts.New(sess) + if client == nil { + return nil, errors.New("could not obtain sts client from session") + } + + return client, nil +} diff --git a/awsutil/clients_test.go b/awsutil/clients_test.go new file mode 100644 index 0000000..a6db1b2 --- /dev/null +++ b/awsutil/clients_test.go @@ -0,0 +1,141 @@ +package awsutil + +import ( + "errors" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/iam/iamiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/stretchr/testify/require" +) + +const testOptionErr = "test option error" +const testBadClientType = "badclienttype" + +func testWithOptionError(_ *options) error { + return errors.New(testOptionErr) +} + +func testWithBadClientType(o *options) error { + o.withClientType = testBadClientType + return nil +} + +func TestCredentialsConfigIAMClient(t *testing.T) { + cases := []struct { + name string + credentialsConfig *CredentialsConfig + opts []Option + require func(t *testing.T, actual iamiface.IAMAPI) + requireErr string + }{ + { + name: "options error", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{testWithOptionError}, + requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), + }, + { + name: "session error", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{testWithBadClientType}, + requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), + }, + { + name: "with mock IAM session", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{WithIAMAPIFunc(NewMockIAM())}, + require: func(t *testing.T, actual iamiface.IAMAPI) { + t.Helper() + require := require.New(t) + require.Equal(&MockIAM{}, actual) + }, + }, + { + name: "no mock client", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{}, + require: func(t *testing.T, actual iamiface.IAMAPI) { + t.Helper() + require := require.New(t) + require.IsType(&iam.IAM{}, actual) + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + actual, err := tc.credentialsConfig.IAMClient(tc.opts...) + if tc.requireErr != "" { + require.EqualError(err, tc.requireErr) + return + } + + require.NoError(err) + tc.require(t, actual) + }) + } +} + +func TestCredentialsConfigSTSClient(t *testing.T) { + cases := []struct { + name string + credentialsConfig *CredentialsConfig + opts []Option + require func(t *testing.T, actual stsiface.STSAPI) + requireErr string + }{ + { + name: "options error", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{testWithOptionError}, + requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), + }, + { + name: "session error", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{testWithBadClientType}, + requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), + }, + { + name: "with mock STS session", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{WithSTSAPIFunc(NewMockSTS())}, + require: func(t *testing.T, actual stsiface.STSAPI) { + t.Helper() + require := require.New(t) + require.Equal(&MockSTS{}, actual) + }, + }, + { + name: "no mock client", + credentialsConfig: &CredentialsConfig{}, + opts: []Option{}, + require: func(t *testing.T, actual stsiface.STSAPI) { + t.Helper() + require := require.New(t) + require.IsType(&sts.STS{}, actual) + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + actual, err := tc.credentialsConfig.STSClient(tc.opts...) + if tc.requireErr != "" { + require.EqualError(err, tc.requireErr) + return + } + + require.NoError(err) + tc.require(t, actual) + }) + } +} diff --git a/awsutil/mocks.go b/awsutil/mocks.go index 43e4bb2..9ef1ad9 100644 --- a/awsutil/mocks.go +++ b/awsutil/mocks.go @@ -1,26 +1,157 @@ package awsutil import ( + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" ) +// MockIAM provides a way to mock the AWS IAM API. type MockIAM struct { iamiface.IAMAPI CreateAccessKeyOutput *iam.CreateAccessKeyOutput - DeleteAccessKeyOutput *iam.DeleteAccessKeyOutput + CreateAccessKeyError error + DeleteAccessKeyError error GetUserOutput *iam.GetUserOutput + GetUserError error +} + +// MockIAMOption is a function for setting the various fields on a MockIAM +// object. +type MockIAMOption func(m *MockIAM) error + +// WithCreateAccessKeyOutput sets the output for the CreateAccessKey method. +func WithCreateAccessKeyOutput(o *iam.CreateAccessKeyOutput) MockIAMOption { + return func(m *MockIAM) error { + m.CreateAccessKeyOutput = o + return nil + } +} + +// WithCreateAccessKeyError sets the error output for the CreateAccessKey +// method. +func WithCreateAccessKeyError(e error) MockIAMOption { + return func(m *MockIAM) error { + m.CreateAccessKeyError = e + return nil + } +} + +// WithDeleteAccessKeyError sets the error output for the DeleteAccessKey +// method. +func WithDeleteAccessKeyError(e error) MockIAMOption { + return func(m *MockIAM) error { + m.DeleteAccessKeyError = e + return nil + } +} + +// WithGetUserOutput sets the output for the GetUser method. +func WithGetUserOutput(o *iam.GetUserOutput) MockIAMOption { + return func(m *MockIAM) error { + m.GetUserOutput = o + return nil + } +} + +// WithGetUserError sets the error output for the GetUser method. +func WithGetUserError(e error) MockIAMOption { + return func(m *MockIAM) error { + m.GetUserError = e + return nil + } +} + +// NewMockIAM provides a factory function to use with the WithIAMAPIFunc +// option. +func NewMockIAM(opts ...MockIAMOption) IAMAPIFunc { + return func(_ *session.Session) (iamiface.IAMAPI, error) { + m := new(MockIAM) + for _, opt := range opts { + if err := opt(m); err != nil { + return nil, err + } + } + + return m, nil + } } func (m *MockIAM) CreateAccessKey(*iam.CreateAccessKeyInput) (*iam.CreateAccessKeyOutput, error) { + if m.CreateAccessKeyError != nil { + return nil, m.CreateAccessKeyError + } + return m.CreateAccessKeyOutput, nil } func (m *MockIAM) DeleteAccessKey(*iam.DeleteAccessKeyInput) (*iam.DeleteAccessKeyOutput, error) { - return m.DeleteAccessKeyOutput, nil + return &iam.DeleteAccessKeyOutput{}, m.DeleteAccessKeyError } func (m *MockIAM) GetUser(*iam.GetUserInput) (*iam.GetUserOutput, error) { + if m.GetUserError != nil { + return nil, m.GetUserError + } + return m.GetUserOutput, nil } + +// MockSTS provides a way to mock the AWS STS API. +type MockSTS struct { + stsiface.STSAPI + + GetCallerIdentityOutput *sts.GetCallerIdentityOutput + GetCallerIdentityError error +} + +// MockSTSOption is a function for setting the various fields on a MockSTS +// object. +type MockSTSOption func(m *MockSTS) error + +// WithGetCallerIdentityOutput sets the output for the GetCallerIdentity +// method. +func WithGetCallerIdentityOutput(o *sts.GetCallerIdentityOutput) MockSTSOption { + return func(m *MockSTS) error { + m.GetCallerIdentityOutput = o + return nil + } +} + +// WithGetCallerIdentityError sets the error output for the GetCallerIdentity +// method. +func WithGetCallerIdentityError(e error) MockSTSOption { + return func(m *MockSTS) error { + m.GetCallerIdentityError = e + return nil + } +} + +// NewMockSTS provides a factory function to use with the WithSTSAPIFunc +// option. +// +// If withGetCallerIdentityError is supplied, calls to GetCallerIdentity will +// return the supplied error. Otherwise, a basic mock API output is returned. +func NewMockSTS(opts ...MockSTSOption) STSAPIFunc { + return func(_ *session.Session) (stsiface.STSAPI, error) { + m := new(MockSTS) + for _, opt := range opts { + if err := opt(m); err != nil { + return nil, err + } + } + + return m, nil + } +} + +func (m *MockSTS) GetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { + if m.GetCallerIdentityError != nil { + return nil, m.GetCallerIdentityError + } + + return m.GetCallerIdentityOutput, nil +} diff --git a/awsutil/mocks_test.go b/awsutil/mocks_test.go new file mode 100644 index 0000000..85b3e60 --- /dev/null +++ b/awsutil/mocks_test.go @@ -0,0 +1,141 @@ +package awsutil + +import ( + "errors" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMockIAM(t *testing.T) { + cases := []struct { + name string + opts []MockIAMOption + expectedCreateAccessKeyOutput *iam.CreateAccessKeyOutput + expectedCreateAccessKeyError error + expectedDeleteAccessKeyError error + expectedGetUserOutput *iam.GetUserOutput + expectedGetUserError error + }{ + { + name: "CreateAccessKeyOutput", + opts: []MockIAMOption{WithCreateAccessKeyOutput( + &iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + }, + }, + )}, + expectedCreateAccessKeyOutput: &iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + }, + }, + }, + { + name: "CreateAccessKeyError", + opts: []MockIAMOption{WithCreateAccessKeyError(errors.New("testerr"))}, + expectedCreateAccessKeyError: errors.New("testerr"), + }, + { + name: "DeleteAccessKeyError", + opts: []MockIAMOption{WithDeleteAccessKeyError(errors.New("testerr"))}, + expectedDeleteAccessKeyError: errors.New("testerr"), + }, + { + name: "GetUserOutput", + opts: []MockIAMOption{WithGetUserOutput( + &iam.GetUserOutput{ + User: &iam.User{ + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + UserName: aws.String("JohnDoe"), + }, + }, + )}, + expectedGetUserOutput: &iam.GetUserOutput{ + User: &iam.User{ + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + UserName: aws.String("JohnDoe"), + }, + }, + }, + { + name: "GetUserError", + opts: []MockIAMOption{WithGetUserError(errors.New("testerr"))}, + expectedGetUserError: errors.New("testerr"), + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + f := NewMockIAM(tc.opts...) + m, err := f(nil) + require.NoError(err) // Nothing returns an error right now + actualCreateAccessKeyOutput, actualCreateAccessKeyError := m.CreateAccessKey(nil) + _, actualDeleteAccessKeyError := m.DeleteAccessKey(nil) + actualGetUserOutput, actualGetUserError := m.GetUser(nil) + assert.Equal(tc.expectedCreateAccessKeyOutput, actualCreateAccessKeyOutput) + assert.Equal(tc.expectedCreateAccessKeyError, actualCreateAccessKeyError) + assert.Equal(tc.expectedDeleteAccessKeyError, actualDeleteAccessKeyError) + assert.Equal(tc.expectedGetUserOutput, actualGetUserOutput) + assert.Equal(tc.expectedGetUserError, actualGetUserError) + }) + } +} + +func TestMockSTS(t *testing.T) { + cases := []struct { + name string + opts []MockSTSOption + expectedGetCallerIdentityOutput *sts.GetCallerIdentityOutput + expectedGetCallerIdentityError error + }{ + { + name: "GetCallerIdentityOutput", + opts: []MockSTSOption{WithGetCallerIdentityOutput( + &sts.GetCallerIdentityOutput{ + Account: aws.String("1234567890"), + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + }, + )}, + expectedGetCallerIdentityOutput: &sts.GetCallerIdentityOutput{ + Account: aws.String("1234567890"), + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + }, + }, + { + name: "GetCallerIdentityError", + opts: []MockSTSOption{WithGetCallerIdentityError(errors.New("testerr"))}, + expectedGetCallerIdentityError: errors.New("testerr"), + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + f := NewMockSTS(tc.opts...) + m, err := f(nil) + require.NoError(err) // Nothing returns an error right now + actualGetCallerIdentityOutput, actualGetCallerIdentityError := m.GetCallerIdentity(nil) + assert.Equal(tc.expectedGetCallerIdentityOutput, actualGetCallerIdentityOutput) + assert.Equal(tc.expectedGetCallerIdentityError, actualGetCallerIdentityError) + }) + } +} diff --git a/awsutil/options.go b/awsutil/options.go index d3da48f..0d66a16 100644 --- a/awsutil/options.go +++ b/awsutil/options.go @@ -42,6 +42,8 @@ type options struct { withRegion string withHttpClient *http.Client withValidityCheckTimeout time.Duration + withIAMAPIFunc IAMAPIFunc + withSTSAPIFunc STSAPIFunc } func getDefaultOptions() options { @@ -170,3 +172,21 @@ func WithValidityCheckTimeout(with time.Duration) Option { return nil } } + +// WithIAMAPIFunc allows passing in an IAM interface constructor for mocking +// the AWS IAM API. +func WithIAMAPIFunc(with IAMAPIFunc) Option { + return func(o *options) error { + o.withIAMAPIFunc = with + return nil + } +} + +// WithSTSAPIFunc allows passing in a STS interface constructor for mocking the +// AWS STS API. +func WithSTSAPIFunc(with STSAPIFunc) Option { + return func(o *options) error { + o.withSTSAPIFunc = with + return nil + } +} diff --git a/awsutil/options_test.go b/awsutil/options_test.go index 65fe352..9e3169e 100644 --- a/awsutil/options_test.go +++ b/awsutil/options_test.go @@ -118,4 +118,14 @@ func Test_GetOpts(t *testing.T) { require.NoError(t, err) assert.Equal(t, opts.withValidityCheckTimeout, time.Second) }) + t.Run("withIAMIface", func(t *testing.T) { + opts, err := getOpts(WithIAMAPIFunc(NewMockIAM())) + require.NoError(t, err) + assert.NotNil(t, opts.withIAMAPIFunc) + }) + t.Run("withSTSIface", func(t *testing.T) { + opts, err := getOpts(WithSTSAPIFunc(NewMockSTS())) + require.NoError(t, err) + assert.NotNil(t, opts.withSTSAPIFunc) + }) } diff --git a/awsutil/rotate.go b/awsutil/rotate.go index 0fbc7d2..a92fd62 100644 --- a/awsutil/rotate.go +++ b/awsutil/rotate.go @@ -23,9 +23,12 @@ import ( // if the old one could not be deleted. // // Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUsername, WithValidityCheckTimeout. Note that WithValidityCheckTimeout -// here, when non-zero, controls the WithValidityCheckTimeout option on access key -// creation. See CreateAccessKey for more details. +// WithAwsSession, WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, +// WithSTSAPIFunc +// +// Note that WithValidityCheckTimeout here, when non-zero, controls the +// WithValidityCheckTimeout option on access key creation. See CreateAccessKey +// for more details. func (c *CredentialsConfig) RotateKeys(opt ...Option) error { if c.AccessKey == "" || c.SecretKey == "" { return errors.New("cannot rotate credentials when either access_key or secret_key is empty") @@ -64,7 +67,8 @@ func (c *CredentialsConfig) RotateKeys(opt ...Option) error { // CreateAccessKey creates a new access/secret key pair. // // Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUsername, WithValidityCheckTimeout +// WithAwsSession, WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, +// WithSTSAPIFunc // // When WithValidityCheckTimeout is non-zero, it specifies a timeout to wait on // the created credentials to be valid and ready for use. @@ -74,17 +78,9 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey return nil, fmt.Errorf("error reading options in CreateAccessKey: %w", err) } - sess := opts.withAwsSession - if sess == nil { - sess, err = c.GetSession(opt...) - if err != nil { - return nil, fmt.Errorf("error calling GetSession: %w", err) - } - } - - client := iam.New(sess) - if client == nil { - return nil, errors.New("could not obtain iam client from session") + client, err := c.IAMClient(opt...) + if err != nil { + return nil, fmt.Errorf("error loading IAM client: %w", err) } var getUserInput iam.GetUserInput @@ -112,9 +108,12 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey if err != nil { return nil, fmt.Errorf("error calling aws.CreateAccessKey: %w", err) } - if createAccessKeyRes.AccessKey == nil { + if createAccessKeyRes == nil { return nil, fmt.Errorf("nil response from aws.CreateAccessKey") } + if createAccessKeyRes.AccessKey == nil { + return nil, fmt.Errorf("nil access key in response from aws.CreateAccessKey") + } if createAccessKeyRes.AccessKey.AccessKeyId == nil || createAccessKeyRes.AccessKey.SecretAccessKey == nil { return nil, fmt.Errorf("nil AccessKeyId or SecretAccessKey returned from aws.CreateAccessKey") } @@ -128,7 +127,10 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey SecretKey: *createAccessKeyRes.AccessKey.SecretAccessKey, } - if _, err := newC.GetCallerIdentity(WithValidityCheckTimeout(opts.withValidityCheckTimeout)); err != nil { + if _, err := newC.GetCallerIdentity( + WithValidityCheckTimeout(opts.withValidityCheckTimeout), + WithSTSAPIFunc(opts.withSTSAPIFunc), + ); err != nil { return nil, fmt.Errorf("error verifying new credentials: %w", err) } } @@ -139,24 +141,16 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey // DeleteAccessKey deletes an access key. // // Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUserName +// WithAwsSession, WithUserName, WithIAMAPIFunc func (c *CredentialsConfig) DeleteAccessKey(accessKeyId string, opt ...Option) error { opts, err := getOpts(opt...) if err != nil { return fmt.Errorf("error reading options in RotateKeys: %w", err) } - sess := opts.withAwsSession - if sess == nil { - sess, err = c.GetSession(opt...) - if err != nil { - return fmt.Errorf("error calling GetSession: %w", err) - } - } - - client := iam.New(sess) - if client == nil { - return errors.New("could not obtain iam client from session") + client, err := c.IAMClient(opt...) + if err != nil { + return fmt.Errorf("error loading IAM client: %w", err) } deleteAccessKeyInput := iam.DeleteAccessKeyInput{ @@ -230,17 +224,9 @@ func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIden return nil, fmt.Errorf("error reading options in GetCallerIdentity: %w", err) } - sess := opts.withAwsSession - if sess == nil { - sess, err = c.GetSession(opt...) - if err != nil { - return nil, fmt.Errorf("error calling GetSession: %w", err) - } - } - - client := sts.New(sess) - if client == nil { - return nil, errors.New("could not obtain STS client from session") + client, err := c.STSClient(opt...) + if err != nil { + return nil, fmt.Errorf("error loading STS client: %w", err) } delay := time.Second diff --git a/awsutil/rotate_test.go b/awsutil/rotate_test.go index 858b34b..557384d 100644 --- a/awsutil/rotate_test.go +++ b/awsutil/rotate_test.go @@ -7,7 +7,10 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -131,3 +134,203 @@ func TestCallerIdentityErrorWithValidityCheckTimeout(t *testing.T) { require.NotNil(err) require.Implements((*awserr.Error)(nil), err) } + +func TestCallerIdentityWithSTSMockError(t *testing.T) { + require := require.New(t) + + expectedErr := errors.New("this is the expected error") + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.GetCallerIdentity(WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityError(expectedErr)))) + require.EqualError(err, expectedErr.Error()) +} + +func TestCallerIdentityWithSTSMockNoErorr(t *testing.T) { + require := require.New(t) + + expectedOut := &sts.GetCallerIdentityOutput{ + Account: aws.String("1234567890"), + Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), + UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), + } + + c, err := NewCredentialsConfig() + require.NoError(err) + out, err := c.GetCallerIdentity(WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityOutput(expectedOut)))) + require.NoError(err) + require.Equal(expectedOut, out) +} + +func TestDeleteAccessKeyWithIAMMock(t *testing.T) { + require := require.New(t) + + mockErr := errors.New("this is the expected error") + expectedErr := "error deleting old access key: this is the expected error" + c, err := NewCredentialsConfig() + require.NoError(err) + err = c.DeleteAccessKey("foobar", WithIAMAPIFunc(NewMockIAM(WithDeleteAccessKeyError(mockErr)))) + require.EqualError(err, expectedErr) +} + +func TestCreateAccessKeyWithIAMMockGetUserError(t *testing.T) { + require := require.New(t) + + mockErr := errors.New("this is the expected error") + expectedErr := "error calling aws.GetUser: this is the expected error" + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.CreateAccessKey(WithIAMAPIFunc(NewMockIAM(WithGetUserError(mockErr)))) + require.EqualError(err, expectedErr) +} + +func TestCreateAccessKeyWithIAMMockCreateAccessKeyError(t *testing.T) { + require := require.New(t) + + mockErr := errors.New("this is the expected error") + expectedErr := "error calling aws.CreateAccessKey: this is the expected error" + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.CreateAccessKey(WithIAMAPIFunc(NewMockIAM( + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + WithCreateAccessKeyError(mockErr), + ))) + require.EqualError(err, expectedErr) +} + +func TestCreateAccessKeyWithIAMAndSTSMockGetCallerIdentityError(t *testing.T) { + require := require.New(t) + + mockErr := errors.New("this is the expected error") + expectedErr := "error verifying new credentials: timeout after 1ns waiting for success: this is the expected error" + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.CreateAccessKey( + WithValidityCheckTimeout(time.Nanosecond), + WithIAMAPIFunc(NewMockIAM( + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + }, + }), + )), + WithSTSAPIFunc(NewMockSTS( + WithGetCallerIdentityError(mockErr), + )), + ) + require.EqualError(err, expectedErr) +} + +func TestCreateAccessKeyNilResponse(t *testing.T) { + require := require.New(t) + + expectedErr := "nil response from aws.CreateAccessKey" + c, err := NewCredentialsConfig() + require.NoError(err) + _, err = c.CreateAccessKey( + WithValidityCheckTimeout(time.Nanosecond), + WithIAMAPIFunc(NewMockIAM( + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + )), + ) + require.EqualError(err, expectedErr) +} + +func TestRotateKeysWithMocks(t *testing.T) { + mockErr := errors.New("this is the expected error") + cases := []struct { + name string + mockIAMOpts []MockIAMOption + mockSTSOpts []MockSTSOption + require func(t *testing.T, actual *CredentialsConfig) + requireErr string + }{ + { + name: "CreateAccessKey IAM error", + mockIAMOpts: []MockIAMOption{WithGetUserError(mockErr)}, + requireErr: "error calling CreateAccessKey: error calling aws.GetUser: this is the expected error", + }, + { + name: "CreateAccessKey STS error", + mockIAMOpts: []MockIAMOption{ + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + }, + }), + }, + mockSTSOpts: []MockSTSOption{WithGetCallerIdentityError(mockErr)}, + requireErr: "error calling CreateAccessKey: error verifying new credentials: timeout after 1ns waiting for success: this is the expected error", + }, + { + name: "DeleteAccessKey IAM error", + mockIAMOpts: []MockIAMOption{ + WithGetUserOutput(&iam.GetUserOutput{ + User: &iam.User{ + UserName: aws.String("foobar"), + }, + }), + WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + AccessKeyId: aws.String("foobar"), + SecretAccessKey: aws.String("bazqux"), + UserName: aws.String("foouser"), + }, + }), + // DeleteAccessKeyOutput w/o error is a no-op in the mock and + // will return without additional stubbing + }, + mockSTSOpts: []MockSTSOption{WithGetCallerIdentityOutput(&sts.GetCallerIdentityOutput{})}, + require: func(t *testing.T, actual *CredentialsConfig) { + t.Helper() + require := require.New(t) + + require.Equal("foobar", actual.AccessKey) + require.Equal("bazqux", actual.SecretKey) + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + c, err := NewCredentialsConfig( + WithAccessKey("foo"), + WithSecretKey("bar"), + ) + require.NoError(err) + err = c.RotateKeys( + WithIAMAPIFunc(NewMockIAM(tc.mockIAMOpts...)), + WithSTSAPIFunc(NewMockSTS(tc.mockSTSOpts...)), + WithValidityCheckTimeout(time.Nanosecond), + ) + if tc.requireErr != "" { + require.EqualError(err, tc.requireErr) + return + } + + require.NoError(err) + tc.require(t, c) + }) + } +}