-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
awsutil: add ability to mock IAM and STS APIs (#12)
* awsutil: add ability to mock IAM and STS APIs This adds the ability to supply mock IAM and STS interfaces to use with the RotateKeys, CreateAccessKey, DeleteAccessKey, and GetCallerIdentity methods. Additionally, work has been done to incorporate the existing IAM mock object into the package, in addition to adding a similar STS mock that allows for the mocking of the GetCallerIdentity method. Factory functions are also included to allow these to be incorporated seamlessly into the call path, in addition to allowing for introspection into any session data being received by the constructor functions. Also adds MockOptionErr to expose mocking out an options error.
- Loading branch information
1 parent
5b6393a
commit f7bda98
Showing
8 changed files
with
767 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package awsutil | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
|
||
"github.com/aws/aws-sdk-go/aws/session" | ||
"github.com/aws/aws-sdk-go/service/iam" | ||
"github.com/aws/aws-sdk-go/service/iam/iamiface" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
"github.com/aws/aws-sdk-go/service/sts/stsiface" | ||
) | ||
|
||
// IAMAPIFunc is a factory function for returning an IAM interface, | ||
// useful for supplying mock interfaces for testing IAM. The session | ||
// is passed into the function in the same way as done with the | ||
// standard iam.New() constructor. | ||
type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error) | ||
|
||
// STSAPIFunc is a factory function for returning a STS interface, | ||
// useful for supplying mock interfaces for testing STS. The session | ||
// is passed into the function in the same way as done with the | ||
// standard sts.New() constructor. | ||
type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error) | ||
|
||
// IAMClient returns an IAM client. | ||
// | ||
// Supported options: WithSession, WithIAMAPIFunc. | ||
// | ||
// If WithIAMAPIFunc is supplied, the included function is used as | ||
// the IAM client constructor instead. This can be used for Mocking | ||
// the IAM API. | ||
func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) { | ||
opts, err := getOpts(opt...) | ||
if err != nil { | ||
return nil, fmt.Errorf("error reading options: %w", err) | ||
} | ||
|
||
sess := opts.withAwsSession | ||
if sess == nil { | ||
sess, err = c.GetSession(opt...) | ||
if err != nil { | ||
return nil, fmt.Errorf("error calling GetSession: %w", err) | ||
} | ||
} | ||
|
||
if opts.withIAMAPIFunc != nil { | ||
return opts.withIAMAPIFunc(sess) | ||
} | ||
|
||
client := iam.New(sess) | ||
if client == nil { | ||
return nil, errors.New("could not obtain iam client from session") | ||
} | ||
|
||
return client, nil | ||
} | ||
|
||
// STSClient returns a STS client. | ||
// | ||
// Supported options: WithSession, WithSTSAPIFunc. | ||
// | ||
// If WithSTSAPIFunc is supplied, the included function is used as | ||
// the STS client constructor instead. This can be used for Mocking | ||
// the STS API. | ||
func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) { | ||
opts, err := getOpts(opt...) | ||
if err != nil { | ||
return nil, fmt.Errorf("error reading options: %w", err) | ||
} | ||
|
||
sess := opts.withAwsSession | ||
if sess == nil { | ||
sess, err = c.GetSession(opt...) | ||
if err != nil { | ||
return nil, fmt.Errorf("error calling GetSession: %w", err) | ||
} | ||
} | ||
|
||
if opts.withSTSAPIFunc != nil { | ||
return opts.withSTSAPIFunc(sess) | ||
} | ||
|
||
client := sts.New(sess) | ||
if client == nil { | ||
return nil, errors.New("could not obtain sts client from session") | ||
} | ||
|
||
return client, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
package awsutil | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/aws/aws-sdk-go/service/iam" | ||
"github.com/aws/aws-sdk-go/service/iam/iamiface" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
"github.com/aws/aws-sdk-go/service/sts/stsiface" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
const testOptionErr = "test option error" | ||
const testBadClientType = "badclienttype" | ||
|
||
func testWithBadClientType(o *options) error { | ||
o.withClientType = testBadClientType | ||
return nil | ||
} | ||
|
||
func TestCredentialsConfigIAMClient(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
credentialsConfig *CredentialsConfig | ||
opts []Option | ||
require func(t *testing.T, actual iamiface.IAMAPI) | ||
requireErr string | ||
}{ | ||
{ | ||
name: "options error", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{MockOptionErr(errors.New(testOptionErr))}, | ||
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), | ||
}, | ||
{ | ||
name: "session error", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{testWithBadClientType}, | ||
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), | ||
}, | ||
{ | ||
name: "with mock IAM session", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{WithIAMAPIFunc(NewMockIAM())}, | ||
require: func(t *testing.T, actual iamiface.IAMAPI) { | ||
t.Helper() | ||
require := require.New(t) | ||
require.Equal(&MockIAM{}, actual) | ||
}, | ||
}, | ||
{ | ||
name: "no mock client", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{}, | ||
require: func(t *testing.T, actual iamiface.IAMAPI) { | ||
t.Helper() | ||
require := require.New(t) | ||
require.IsType(&iam.IAM{}, actual) | ||
}, | ||
}, | ||
} | ||
|
||
for _, tc := range cases { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
require := require.New(t) | ||
actual, err := tc.credentialsConfig.IAMClient(tc.opts...) | ||
if tc.requireErr != "" { | ||
require.EqualError(err, tc.requireErr) | ||
return | ||
} | ||
|
||
require.NoError(err) | ||
tc.require(t, actual) | ||
}) | ||
} | ||
} | ||
|
||
func TestCredentialsConfigSTSClient(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
credentialsConfig *CredentialsConfig | ||
opts []Option | ||
require func(t *testing.T, actual stsiface.STSAPI) | ||
requireErr string | ||
}{ | ||
{ | ||
name: "options error", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{MockOptionErr(errors.New(testOptionErr))}, | ||
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), | ||
}, | ||
{ | ||
name: "session error", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{testWithBadClientType}, | ||
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), | ||
}, | ||
{ | ||
name: "with mock STS session", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{WithSTSAPIFunc(NewMockSTS())}, | ||
require: func(t *testing.T, actual stsiface.STSAPI) { | ||
t.Helper() | ||
require := require.New(t) | ||
require.Equal(&MockSTS{}, actual) | ||
}, | ||
}, | ||
{ | ||
name: "no mock client", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{}, | ||
require: func(t *testing.T, actual stsiface.STSAPI) { | ||
t.Helper() | ||
require := require.New(t) | ||
require.IsType(&sts.STS{}, actual) | ||
}, | ||
}, | ||
} | ||
|
||
for _, tc := range cases { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
require := require.New(t) | ||
actual, err := tc.credentialsConfig.STSClient(tc.opts...) | ||
if tc.requireErr != "" { | ||
require.EqualError(err, tc.requireErr) | ||
return | ||
} | ||
|
||
require.NoError(err) | ||
tc.require(t, actual) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,164 @@ | ||
package awsutil | ||
|
||
import ( | ||
"github.com/aws/aws-sdk-go/aws/session" | ||
"github.com/aws/aws-sdk-go/service/iam" | ||
"github.com/aws/aws-sdk-go/service/iam/iamiface" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
"github.com/aws/aws-sdk-go/service/sts/stsiface" | ||
) | ||
|
||
// MockOptionErr provides a mock option error for use with testing. | ||
func MockOptionErr(withErr error) Option { | ||
return func(_ *options) error { | ||
return withErr | ||
} | ||
} | ||
|
||
// MockIAM provides a way to mock the AWS IAM API. | ||
type MockIAM struct { | ||
iamiface.IAMAPI | ||
|
||
CreateAccessKeyOutput *iam.CreateAccessKeyOutput | ||
DeleteAccessKeyOutput *iam.DeleteAccessKeyOutput | ||
CreateAccessKeyError error | ||
DeleteAccessKeyError error | ||
GetUserOutput *iam.GetUserOutput | ||
GetUserError error | ||
} | ||
|
||
// MockIAMOption is a function for setting the various fields on a MockIAM | ||
// object. | ||
type MockIAMOption func(m *MockIAM) error | ||
|
||
// WithCreateAccessKeyOutput sets the output for the CreateAccessKey method. | ||
func WithCreateAccessKeyOutput(o *iam.CreateAccessKeyOutput) MockIAMOption { | ||
return func(m *MockIAM) error { | ||
m.CreateAccessKeyOutput = o | ||
return nil | ||
} | ||
} | ||
|
||
// WithCreateAccessKeyError sets the error output for the CreateAccessKey | ||
// method. | ||
func WithCreateAccessKeyError(e error) MockIAMOption { | ||
return func(m *MockIAM) error { | ||
m.CreateAccessKeyError = e | ||
return nil | ||
} | ||
} | ||
|
||
// WithDeleteAccessKeyError sets the error output for the DeleteAccessKey | ||
// method. | ||
func WithDeleteAccessKeyError(e error) MockIAMOption { | ||
return func(m *MockIAM) error { | ||
m.DeleteAccessKeyError = e | ||
return nil | ||
} | ||
} | ||
|
||
// WithGetUserOutput sets the output for the GetUser method. | ||
func WithGetUserOutput(o *iam.GetUserOutput) MockIAMOption { | ||
return func(m *MockIAM) error { | ||
m.GetUserOutput = o | ||
return nil | ||
} | ||
} | ||
|
||
// WithGetUserError sets the error output for the GetUser method. | ||
func WithGetUserError(e error) MockIAMOption { | ||
return func(m *MockIAM) error { | ||
m.GetUserError = e | ||
return nil | ||
} | ||
} | ||
|
||
// NewMockIAM provides a factory function to use with the WithIAMAPIFunc | ||
// option. | ||
func NewMockIAM(opts ...MockIAMOption) IAMAPIFunc { | ||
return func(_ *session.Session) (iamiface.IAMAPI, error) { | ||
m := new(MockIAM) | ||
for _, opt := range opts { | ||
if err := opt(m); err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
return m, nil | ||
} | ||
} | ||
|
||
func (m *MockIAM) CreateAccessKey(*iam.CreateAccessKeyInput) (*iam.CreateAccessKeyOutput, error) { | ||
if m.CreateAccessKeyError != nil { | ||
return nil, m.CreateAccessKeyError | ||
} | ||
|
||
return m.CreateAccessKeyOutput, nil | ||
} | ||
|
||
func (m *MockIAM) DeleteAccessKey(*iam.DeleteAccessKeyInput) (*iam.DeleteAccessKeyOutput, error) { | ||
return m.DeleteAccessKeyOutput, nil | ||
return &iam.DeleteAccessKeyOutput{}, m.DeleteAccessKeyError | ||
} | ||
|
||
func (m *MockIAM) GetUser(*iam.GetUserInput) (*iam.GetUserOutput, error) { | ||
if m.GetUserError != nil { | ||
return nil, m.GetUserError | ||
} | ||
|
||
return m.GetUserOutput, nil | ||
} | ||
|
||
// MockSTS provides a way to mock the AWS STS API. | ||
type MockSTS struct { | ||
stsiface.STSAPI | ||
|
||
GetCallerIdentityOutput *sts.GetCallerIdentityOutput | ||
GetCallerIdentityError error | ||
} | ||
|
||
// MockSTSOption is a function for setting the various fields on a MockSTS | ||
// object. | ||
type MockSTSOption func(m *MockSTS) error | ||
|
||
// WithGetCallerIdentityOutput sets the output for the GetCallerIdentity | ||
// method. | ||
func WithGetCallerIdentityOutput(o *sts.GetCallerIdentityOutput) MockSTSOption { | ||
return func(m *MockSTS) error { | ||
m.GetCallerIdentityOutput = o | ||
return nil | ||
} | ||
} | ||
|
||
// WithGetCallerIdentityError sets the error output for the GetCallerIdentity | ||
// method. | ||
func WithGetCallerIdentityError(e error) MockSTSOption { | ||
return func(m *MockSTS) error { | ||
m.GetCallerIdentityError = e | ||
return nil | ||
} | ||
} | ||
|
||
// NewMockSTS provides a factory function to use with the WithSTSAPIFunc | ||
// option. | ||
// | ||
// If withGetCallerIdentityError is supplied, calls to GetCallerIdentity will | ||
// return the supplied error. Otherwise, a basic mock API output is returned. | ||
func NewMockSTS(opts ...MockSTSOption) STSAPIFunc { | ||
return func(_ *session.Session) (stsiface.STSAPI, error) { | ||
m := new(MockSTS) | ||
for _, opt := range opts { | ||
if err := opt(m); err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
return m, nil | ||
} | ||
} | ||
|
||
func (m *MockSTS) GetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { | ||
if m.GetCallerIdentityError != nil { | ||
return nil, m.GetCallerIdentityError | ||
} | ||
|
||
return m.GetCallerIdentityOutput, nil | ||
} |
Oops, something went wrong.