Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AWS auth renewals #8991

Merged
merged 6 commits into from
May 18, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions builtin/credential/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/awsutil"
Expand Down Expand Up @@ -57,13 +56,13 @@ type backend struct {
// This avoids the overhead of creating a client object for every login request.
// When the credentials are modified or deleted, all the cached client objects
// will be flushed. The empty STS role signifies the master account
EC2ClientsMap map[string]map[string]*ec2.EC2
EC2ClientsMap map[string]map[string]ec2Client

// Map to hold the IAM client objects indexed by region and STS role.
// This avoids the overhead of creating a client object for every login request.
// When the credentials are modified or deleted, all the cached client objects
// will be flushed. The empty STS role signifies the master account
IAMClientsMap map[string]map[string]*iam.IAM
IAMClientsMap map[string]map[string]iamClient

// Map to associate a partition to a random region in that partition. Users of
// this don't care what region in the partition they use, but there is some client
Expand Down Expand Up @@ -97,8 +96,8 @@ func Backend(_ *logical.BackendConfig) (*backend, error) {
// Setting the periodic func to be run once in an hour.
// If there is a real need, this can be made configurable.
tidyCooldownPeriod: time.Hour,
EC2ClientsMap: make(map[string]map[string]*ec2.EC2),
IAMClientsMap: make(map[string]map[string]*iam.IAM),
EC2ClientsMap: make(map[string]map[string]ec2Client),
IAMClientsMap: make(map[string]map[string]iamClient),
iamUserIdToArnCache: cache.New(7*24*time.Hour, 24*time.Hour),
tidyBlacklistCASGuard: new(uint32),
tidyWhitelistCASGuard: new(uint32),
Expand Down
2 changes: 1 addition & 1 deletion builtin/credential/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1488,7 +1488,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
return
}

stsService := sts.New(awsSession)
stsService := newSTSClient(awsSession)
stsInputParams := &sts.GetCallerIdentityInput{}

testIdentity, err := stsService.GetCallerIdentity(stsInputParams)
Expand Down
2 changes: 1 addition & 1 deletion builtin/credential/aws/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func GenerateLoginData(creds *credentials.Credentials, headerValue, configuredRe
}

var params *sts.GetCallerIdentityInput
svc := sts.New(stsSession)
svc := newSTSClient(stsSession)
stsRequest, _ := svc.GetCallerIdentityRequest(params)

// Inject the required auth header value, if supplied, and then sign the request including that header
Expand Down
93 changes: 86 additions & 7 deletions builtin/credential/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
Expand All @@ -16,6 +17,30 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)

var (
// These variables are intended to be set by tests. If set, the given
// client will override the AWS client, allowing client responses to
// be mocked out.
mockEC2Client ec2Client = nil
mockIAMClient iamClient = nil
mockSTSClient stsClient = nil
)

type ec2Client interface {
DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
}

type iamClient interface {
GetInstanceProfile(*iam.GetInstanceProfileInput) (*iam.GetInstanceProfileOutput, error)
GetRole(*iam.GetRoleInput) (*iam.GetRoleOutput, error)
GetUser(*iam.GetUserInput) (*iam.GetUserOutput, error)
}

type stsClient interface {
GetCallerIdentity(*sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error)
GetCallerIdentityRequest(*sts.GetCallerIdentityInput) (req *request.Request, output *sts.GetCallerIdentityOutput)
}

// getRawClientConfig creates a aws-sdk-go config, which is used to create client
// that can interact with AWS API. This builds credentials in the following
// order of preference:
Expand Down Expand Up @@ -115,7 +140,7 @@ func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region
if err != nil {
return nil, err
}
client := sts.New(sess)
client := newSTSClient(sess)
if client == nil {
return nil, errwrap.Wrapf("could not obtain sts client: {{err}}", err)
}
Expand Down Expand Up @@ -192,7 +217,7 @@ func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, acco
}

// clientEC2 creates a client to interact with AWS EC2 API
func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (ec2Client, error) {
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
if err != nil {
return nil, err
Expand Down Expand Up @@ -231,12 +256,12 @@ func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, acco
if err != nil {
return nil, err
}
client := ec2.New(sess)
client := newEC2Client(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain ec2 client")
}
if _, ok := b.EC2ClientsMap[region]; !ok {
b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client}
b.EC2ClientsMap[region] = map[string]ec2Client{stsRole: client}
} else {
b.EC2ClientsMap[region][stsRole] = client
}
Expand All @@ -245,7 +270,7 @@ func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, acco
}

// clientIAM creates a client to interact with AWS IAM API
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (iamClient, error) {
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
if err != nil {
return nil, err
Expand Down Expand Up @@ -291,14 +316,68 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, acco
if err != nil {
return nil, err
}
client := iam.New(sess)
client := newIAMClient(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain iam client")
}
if _, ok := b.IAMClientsMap[region]; !ok {
b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client}
b.IAMClientsMap[region] = map[string]iamClient{stsRole: client}
} else {
b.IAMClientsMap[region][stsRole] = client
}
return b.IAMClientsMap[region][stsRole], nil
}

// newEC2Client should be used instead of using ec2.New()
// directly because it allows us to mock out the EC2 client
// as needed for testing.
func newEC2Client(sess *session.Session) ec2Client {
if mockEC2Client != nil {
return &replacableEC2Client{
ec2Client: mockEC2Client,
}
}
return &replacableEC2Client{
ec2Client: ec2.New(sess),
}
}

type replacableEC2Client struct {
ec2Client
}

// newIAMClient should be used instead of using iam.New()
// directly because it allows us to mock out the IAM client
// as needed for testing.
func newIAMClient(sess *session.Session) iamClient {
if mockIAMClient != nil {
return &replacableIAMClient{
iamClient: mockIAMClient,
}
}
return &replacableIAMClient{
iamClient: iam.New(sess),
}
}

type replacableIAMClient struct {
iamClient
}

// newSTSClient should be used instead of using sts.New()
// directly because it allows us to mock out the STS client
// as needed for testing.
func newSTSClient(sess *session.Session) stsClient {
if mockSTSClient != nil {
return &replacableSTSClient{
stsClient: mockSTSClient,
}
}
return &replacableSTSClient{
stsClient: sts.New(sess),
}
}

type replacableSTSClient struct {
stsClient
}
88 changes: 56 additions & 32 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ needs to be supplied along with 'identity' parameter.`,

// instanceIamRoleARN fetches the IAM role ARN associated with the given
// instance profile name
func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName string) (string, error) {
func (b *backend) instanceIamRoleARN(iamClient iamClient, instanceProfileName string) (string, error) {
if iamClient == nil {
return "", fmt.Errorf("nil iamClient")
}
Expand Down Expand Up @@ -842,6 +842,11 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
Alias: &logical.Alias{
Name: identityAlias,
},
InternalData: map[string]interface{}{
"instance_id": identityDocParsed.InstanceID,
"region": identityDocParsed.Region,
"account_id": identityDocParsed.AccountID,
},
}
roleEntry.PopulateTokenAuth(auth)
if err := identityConfigEntry.EC2AuthMetadataHandler.PopulateDesiredMetadata(auth, map[string]string{
Expand Down Expand Up @@ -963,9 +968,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, data
}

func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
canonicalArn := req.Auth.Metadata["canonical_arn"]
if canonicalArn == "" {
return nil, fmt.Errorf("unable to retrieve canonical ARN from metadata during renewal")
canonicalArn, err := getMetadataValue(req.Auth, "canonical_arn")
if err != nil {
return nil, err
}

roleName := ""
Expand Down Expand Up @@ -996,16 +1001,19 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
// renew existing tokens.
if roleEntry.InferredEntityType != "" {
if roleEntry.InferredEntityType == ec2EntityType {
instanceID, ok := req.Auth.Metadata["inferred_entity_id"]
if !ok {
return nil, fmt.Errorf("no inferred entity ID in auth metadata")
instanceID, err := getMetadataValue(req.Auth, "inferred_entity_id")
if err != nil {
return nil, err
}
instanceRegion, ok := req.Auth.Metadata["inferred_aws_region"]
if !ok {
return nil, fmt.Errorf("no inferred AWS region in auth metadata")
instanceRegion, err := getMetadataValue(req.Auth, "inferred_aws_region")
if err != nil {
return nil, err
}
_, err := b.validateInstance(ctx, req.Storage, instanceID, instanceRegion, req.Auth.Metadata["account_id"])
accountID, err := getMetadataValue(req.Auth, "account_id")
if err != nil {
return nil, err
}
if _, err := b.validateInstance(ctx, req.Storage, instanceID, instanceRegion, accountID); err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("failed to verify instance ID %q: {{err}}", instanceID), err)
}
} else {
Expand All @@ -1027,9 +1035,9 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
// implies that roleEntry.ResolveAWSUniqueIDs is true)
// 2: roleEntry.ResolveAWSUniqueIDs is false and canonical_arn is in roleEntry.BoundIamPrincipalARNs
// 3: Full ARN matches one of the wildcard globs in roleEntry.BoundIamPrincipalARNs
clientUserId, ok := req.Auth.Metadata["client_user_id"]
clientUserId, err := getMetadataValue(req.Auth, "client_user_id")
switch {
case ok && strutil.StrListContains(roleEntry.BoundIamPrincipalIDs, clientUserId): // check 1 passed
case err == nil && strutil.StrListContains(roleEntry.BoundIamPrincipalIDs, clientUserId): // check 1 passed
case !roleEntry.ResolveAWSUniqueIDs && strutil.StrListContains(roleEntry.BoundIamPrincipalARNs, canonicalArn): // check 2 passed
default:
// check 3 is a bit more complex, so we do it last
Expand Down Expand Up @@ -1070,28 +1078,22 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
return resp, nil
}

func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
instanceID := req.Auth.Metadata["instance_id"]
if instanceID == "" {
return nil, fmt.Errorf("unable to fetch instance ID from metadata during renewal")
func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
instanceID, err := getMetadataValue(req.Auth, "instance_id")
if err != nil {
return nil, err
}

region := req.Auth.Metadata["region"]
if region == "" {
return nil, fmt.Errorf("unable to fetch region from metadata during renewal")
region, err := getMetadataValue(req.Auth, "region")
if err != nil {
return nil, err
}

// Ensure backwards compatibility for older clients without account_id saved in metadata
accountID, ok := req.Auth.Metadata["account_id"]
if ok {
if accountID == "" {
return nil, fmt.Errorf("unable to fetch account_id from metadata during renewal")
}
accountID, err := getMetadataValue(req.Auth, "account_id")
if err != nil {
return nil, err
}

// Cross check that the instance is still in 'running' state
_, err := b.validateInstance(ctx, req.Storage, instanceID, region, accountID)
if err != nil {
if _, err := b.validateInstance(ctx, req.Storage, instanceID, region, accountID); err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("failed to verify instance ID %q: {{err}}", instanceID), err)
}

Expand Down Expand Up @@ -1360,8 +1362,13 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
"role_id": roleEntry.RoleID,
},
InternalData: map[string]interface{}{
"role_name": roleName,
"role_id": roleEntry.RoleID,
"role_name": roleName,
"role_id": roleEntry.RoleID,
"canonical_arn": entity.canonicalArn(),
"client_user_id": callerUniqueId,
"inferred_entity_id": inferredEntityID,
"inferred_aws_region": roleEntry.InferredAWSRegion,
"account_id": entity.AccountNumber,
},
DisplayName: entity.FriendlyName,
Alias: &logical.Alias{
Expand Down Expand Up @@ -1709,6 +1716,23 @@ func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage)
}
}

// getMetadataValue attempts to get a metadata key from
// auth.InternalData and if unset, auth.Metadata. If not
// found, returns "".
func getMetadataValue(fromAuth *logical.Auth, forKey string) (string, error) {
if raw, ok := fromAuth.InternalData[forKey]; ok {
if val, ok := raw.(string); ok {
return val, nil
} else {
return "", fmt.Errorf("unable to fetch %q from auth metadata due to type of %T", forKey, raw)
}
}
if val, ok := fromAuth.Metadata[forKey]; ok {
return val, nil
}
return "", fmt.Errorf("%q unfound in auth metadata", forKey)
}

const iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID"

const pathLoginSyn = `
Expand Down
2 changes: 1 addition & 1 deletion builtin/credential/aws/path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ func defaultLoginData() (map[string]interface{}, error) {
return nil, fmt.Errorf("failed to create session: %s", err)
}

stsService := sts.New(awsSession)
stsService := newSTSClient(awsSession)
stsInputParams := &sts.GetCallerIdentityInput{}
stsRequestValid, _ := stsService.GetCallerIdentityRequest(stsInputParams)
stsRequestValid.HTTPRequest.Header.Add(iamServerIdHeader, testVaultHeaderValue)
Expand Down
Loading