Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

awsutil: add ability to mock IAM and STS APIs #12

Merged
merged 2 commits into from
Oct 6, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions awsutil/clients.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package awsutil

import (
"errors"
"fmt"

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

// IAMAPIFunc is a factory function for returning an IAM interface,
// useful for supplying mock interfaces for testing IAM. The session
// is passed into the function in the same way as done with the
// standard iam.New() constructor.
type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error)

// STSAPIFunc is a factory function for returning a STS interface,
// useful for supplying mock interfaces for testing STS. The session
// is passed into the function in the same way as done with the
// standard sts.New() constructor.
type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error)

// IAMClient returns an IAM client.
//
// Supported options: WithSession, WithIAMAPIFunc.
//
// If WithIAMAPIFunc is supplied, the included function is used as
// the IAM client constructor instead. This can be used for Mocking
// the IAM API.
func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

sess := opts.withAwsSession
if sess == nil {
sess, err = c.GetSession(opt...)
if err != nil {
return nil, fmt.Errorf("error calling GetSession: %w", err)
}
}

if opts.withIAMAPIFunc != nil {
return opts.withIAMAPIFunc(sess)
}

client := iam.New(sess)
if client == nil {
return nil, errors.New("could not obtain iam client from session")
}

return client, nil
}

// STSClient returns a STS client.
//
// Supported options: WithSession, WithSTSAPIFunc.
//
// If WithSTSAPIFunc is supplied, the included function is used as
// the STS client constructor instead. This can be used for Mocking
// the STS API.
func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

sess := opts.withAwsSession
if sess == nil {
sess, err = c.GetSession(opt...)
if err != nil {
return nil, fmt.Errorf("error calling GetSession: %w", err)
}
}

if opts.withSTSAPIFunc != nil {
return opts.withSTSAPIFunc(sess)
}

client := sts.New(sess)
if client == nil {
return nil, errors.New("could not obtain sts client from session")
}

return client, nil
}
141 changes: 141 additions & 0 deletions awsutil/clients_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package awsutil

import (
"errors"
"fmt"
"testing"

"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/stretchr/testify/require"
)

const testOptionErr = "test option error"
const testBadClientType = "badclienttype"

func testWithOptionError(_ *options) error {
return errors.New(testOptionErr)
}

func testWithBadClientType(o *options) error {
o.withClientType = testBadClientType
return nil
}

func TestCredentialsConfigIAMClient(t *testing.T) {
cases := []struct {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual iamiface.IAMAPI)
requireErr string
}{
{
name: "options error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithOptionError},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "session error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithBadClientType},
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType),
},
{
name: "with mock IAM session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithIAMAPIFunc(NewMockIAM())},
require: func(t *testing.T, actual iamiface.IAMAPI) {
t.Helper()
require := require.New(t)
require.Equal(&MockIAM{}, actual)
},
},
{
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual iamiface.IAMAPI) {
t.Helper()
require := require.New(t)
require.IsType(&iam.IAM{}, actual)
},
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.IAMClient(tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
}

require.NoError(err)
tc.require(t, actual)
})
}
}

func TestCredentialsConfigSTSClient(t *testing.T) {
cases := []struct {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual stsiface.STSAPI)
requireErr string
}{
{
name: "options error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithOptionError},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "session error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithBadClientType},
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType),
},
{
name: "with mock STS session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithSTSAPIFunc(NewMockSTS())},
require: func(t *testing.T, actual stsiface.STSAPI) {
t.Helper()
require := require.New(t)
require.Equal(&MockSTS{}, actual)
},
},
{
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual stsiface.STSAPI) {
t.Helper()
require := require.New(t)
require.IsType(&sts.STS{}, actual)
},
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.STSClient(tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
}

require.NoError(err)
tc.require(t, actual)
})
}
}
135 changes: 133 additions & 2 deletions awsutil/mocks.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,157 @@
package awsutil

import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

// MockIAM provides a way to mock the AWS IAM API.
type MockIAM struct {
iamiface.IAMAPI

CreateAccessKeyOutput *iam.CreateAccessKeyOutput
DeleteAccessKeyOutput *iam.DeleteAccessKeyOutput
CreateAccessKeyError error
DeleteAccessKeyError error
GetUserOutput *iam.GetUserOutput
GetUserError error
}

// MockIAMOption is a function for setting the various fields on a MockIAM
// object.
type MockIAMOption func(m *MockIAM) error

// WithCreateAccessKeyOutput sets the output for the CreateAccessKey method.
func WithCreateAccessKeyOutput(o *iam.CreateAccessKeyOutput) MockIAMOption {
return func(m *MockIAM) error {
m.CreateAccessKeyOutput = o
return nil
}
}

// WithCreateAccessKeyError sets the error output for the CreateAccessKey
// method.
func WithCreateAccessKeyError(e error) MockIAMOption {
return func(m *MockIAM) error {
m.CreateAccessKeyError = e
return nil
}
}

// WithDeleteAccessKeyError sets the error output for the DeleteAccessKey
// method.
func WithDeleteAccessKeyError(e error) MockIAMOption {
return func(m *MockIAM) error {
m.DeleteAccessKeyError = e
return nil
}
}

// WithGetUserOutput sets the output for the GetUser method.
func WithGetUserOutput(o *iam.GetUserOutput) MockIAMOption {
return func(m *MockIAM) error {
m.GetUserOutput = o
return nil
}
}

// WithGetUserError sets the error output for the GetUser method.
func WithGetUserError(e error) MockIAMOption {
return func(m *MockIAM) error {
m.GetUserError = e
return nil
}
}

// NewMockIAM provides a factory function to use with the WithIAMAPIFunc
// option.
func NewMockIAM(opts ...MockIAMOption) IAMAPIFunc {
return func(_ *session.Session) (iamiface.IAMAPI, error) {
m := new(MockIAM)
for _, opt := range opts {
if err := opt(m); err != nil {
return nil, err
}
}

return m, nil
}
}

func (m *MockIAM) CreateAccessKey(*iam.CreateAccessKeyInput) (*iam.CreateAccessKeyOutput, error) {
if m.CreateAccessKeyError != nil {
return nil, m.CreateAccessKeyError
}

return m.CreateAccessKeyOutput, nil
}

func (m *MockIAM) DeleteAccessKey(*iam.DeleteAccessKeyInput) (*iam.DeleteAccessKeyOutput, error) {
return m.DeleteAccessKeyOutput, nil
return &iam.DeleteAccessKeyOutput{}, m.DeleteAccessKeyError
}

func (m *MockIAM) GetUser(*iam.GetUserInput) (*iam.GetUserOutput, error) {
if m.GetUserError != nil {
return nil, m.GetUserError
}

return m.GetUserOutput, nil
}

// MockSTS provides a way to mock the AWS STS API.
type MockSTS struct {
stsiface.STSAPI

GetCallerIdentityOutput *sts.GetCallerIdentityOutput
GetCallerIdentityError error
}

// MockSTSOption is a function for setting the various fields on a MockSTS
// object.
type MockSTSOption func(m *MockSTS) error

// WithGetCallerIdentityOutput sets the output for the GetCallerIdentity
// method.
func WithGetCallerIdentityOutput(o *sts.GetCallerIdentityOutput) MockSTSOption {
return func(m *MockSTS) error {
m.GetCallerIdentityOutput = o
return nil
}
}

// WithGetCallerIdentityError sets the error output for the GetCallerIdentity
// method.
func WithGetCallerIdentityError(e error) MockSTSOption {
return func(m *MockSTS) error {
m.GetCallerIdentityError = e
return nil
}
}

// NewMockSTS provides a factory function to use with the WithSTSAPIFunc
// option.
//
// If withGetCallerIdentityError is supplied, calls to GetCallerIdentity will
// return the supplied error. Otherwise, a basic mock API output is returned.
func NewMockSTS(opts ...MockSTSOption) STSAPIFunc {
return func(_ *session.Session) (stsiface.STSAPI, error) {
m := new(MockSTS)
for _, opt := range opts {
if err := opt(m); err != nil {
return nil, err
}
}

return m, nil
}
}

func (m *MockSTS) GetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) {
if m.GetCallerIdentityError != nil {
return nil, m.GetCallerIdentityError
}

return m.GetCallerIdentityOutput, nil
}
Loading