Skip to content

Commit

Permalink
aws/credentials/stscreds: Add STS and Assume Role specific retries (#…
Browse files Browse the repository at this point in the history
…2752)

Adds retries to specific STS API errors to the STS AssumeRoleWithWebIdentity credential provider, and STS API operations in general.
  • Loading branch information
jasdel authored Aug 14, 2019
1 parent 0d2fb42 commit d8409ae
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 56 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@
* `service/kinesis`: Add support for retrying service specific API errors ([#2751](https://github.com/aws/aws-sdk-go/pull/2751)
* Adds support for retrying the Kinesis API error, LimitExceededException.
* Fixes [#1376](https://github.com/aws/aws-sdk-go/issues/1376)
* `aws/credentials/stscreds`: Add STS and Assume Role specific retries ([#2752](https://github.com/aws/aws-sdk-go/pull/2752))
* Adds retries to specific STS API errors to the STS AssumeRoleWithWebIdentity credential provider, and STS API operations in general.

### SDK Bugs
7 changes: 5 additions & 2 deletions aws/credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,15 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(now().UnixNano(), 10)
}
resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{
req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})
if err != nil {
// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.
req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
if err := req.Send(); err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
}

Expand Down
124 changes: 70 additions & 54 deletions aws/credentials/stscreds/web_identity_provider_test.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,28 @@
// +build go1.7

package stscreds
package stscreds_test

import (
"fmt"
"net/http"
"reflect"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/sts"
)

type mockSTS struct {
*sts.STS
AssumeRoleWithWebIdentityFn func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error)
}

func (m *mockSTS) AssumeRoleWithWebIdentity(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if m.AssumeRoleWithWebIdentityFn != nil {
return m.AssumeRoleWithWebIdentityFn(input)
}

return nil, nil
}

func TestWebIdentityProviderRetrieve(t *testing.T) {
now = func() time.Time {
return time.Time{}
}

var reqCount int
cases := []struct {
name string
mockSTS *mockSTS
onSendReq func(*testing.T, *request.Request)
roleARN string
tokenFilepath string
sessionName string
Expand All @@ -42,64 +31,91 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
}{
{
name: "session name case",
roleARN: "arn",
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: WebIdentityProviderName,
ProviderName: stscreds.WebIdentityProviderName,
},
},
{
name: "valid case",
roleARN: "arn",
name: "invalid token retry",
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
sessionName: "foo",
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
if reqCount == 0 {
r.HTTPResponse.StatusCode = 400
r.Error = awserr.New(sts.ErrCodeInvalidIdentityTokenException,
"some error message", nil)
return
}

data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: WebIdentityProviderName,
ProviderName: stscreds.WebIdentityProviderName,
},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
p := NewWebIdentityRoleProvider(c.mockSTS, c.roleARN, c.sessionName, c.tokenFilepath)
reqCount = 0

svc := sts.New(unit.Session, &aws.Config{
Logger: t,
})
svc.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{
Name: "custom send stub handler",
Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: 200, Header: http.Header{},
}
c.onSendReq(t, r)
reqCount++
},
})
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalError.Clear()

p := stscreds.NewWebIdentityRoleProvider(svc, c.roleARN, c.sessionName, c.tokenFilepath)
credValue, err := p.Retrieve()
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
Expand Down
11 changes: 11 additions & 0 deletions service/sts/customizations.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package sts

import "github.com/aws/aws-sdk-go/aws/request"

func init() {
initRequest = customizeRequest
}

func customizeRequest(r *request.Request) {
r.RetryErrorCodes = append(r.RetryErrorCodes, ErrCodeIDPCommunicationErrorException)
}
47 changes: 47 additions & 0 deletions service/sts/customizations_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package sts_test

import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/sts"
)
Expand Down Expand Up @@ -43,3 +49,44 @@ func TestUnsignedRequest_AssumeRoleWithWebIdentity(t *testing.T) {
t.Errorf("expect %v, got %v", e, a)
}
}

func TestSTSCustomRetryErrorCodes(t *testing.T) {
svc := sts.New(unit.Session, &aws.Config{
MaxRetries: aws.Int(1),
})
svc.Handlers.Validate.Clear()

const xmlErr = `<ErrorResponse><Error><Code>%s</Code><Message>some error message</Message></Error></ErrorResponse>`
var reqCount int
resps := []*http.Response{
{
StatusCode: 400,
Header: http.Header{},
Body: ioutil.NopCloser(bytes.NewReader(
[]byte(fmt.Sprintf(xmlErr, sts.ErrCodeIDPCommunicationErrorException)),
)),
},
{
StatusCode: 200,
Header: http.Header{},
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
},
}

req, _ := svc.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{})
req.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{
Name: "custom send handler",
Fn: func(r *request.Request) {
r.HTTPResponse = resps[reqCount]
reqCount++
},
})

if err := req.Send(); err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := 2, reqCount; e != a {
t.Errorf("expect %v requests, got %v", e, a)
}
}

0 comments on commit d8409ae

Please sign in to comment.