Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aws auth fixes #9825

Merged
merged 12 commits into from
Aug 25, 2020
9 changes: 9 additions & 0 deletions builtin/credential/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ import (
cache "github.com/patrickmn/go-cache"
)

const amzHeaderPrefix = "X-Amz-"
var defaultAllowedSTSRequestHeaders = []string{
"X-Amz-Date",
"X-Amz-Credential",
"X-Amz-Security-Token",
"X-Amz-Algorithm",
"X-Amz-Signature",
"X-Amz-SignedHeaders"}

func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(conf)
if err != nil {
Expand Down
58 changes: 50 additions & 8 deletions builtin/credential/aws/path_config_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ package awsauth

import (
"context"
"errors"
"net/http"
"net/textproto"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
)

Expand Down Expand Up @@ -53,6 +58,11 @@ func (b *backend) pathConfigClient() *framework.Path {
Default: "",
Description: "Value to require in the X-Vault-AWS-IAM-Server-ID request header",
},
"allowed_sts_header_values": {
Type: framework.TypeCommaStringSlice,
Default: nil,
Description: "List of additional headers that are allowed to be in AWS STS request headers",
},
"max_retries": {
Type: framework.TypeInt,
Default: aws.UseServiceDefaultRetries,
Expand Down Expand Up @@ -136,6 +146,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
"sts_region": clientConfig.STSRegion,
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
"max_retries": clientConfig.MaxRetries,
"allowed_sts_header_values": clientConfig.AllowedSTSHeaderValues,
},
}, nil
}
Expand Down Expand Up @@ -257,6 +268,24 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
configEntry.IAMServerIdHeaderValue = data.Get("iam_server_id_header_value").(string)
}

aHeadersValStr, ok := data.GetOk("allowed_sts_header_values")
if ok {
aHeadersValSl := aHeadersValStr.([]string)
for i, v := range aHeadersValSl {
aHeadersValSl[i] = textproto.CanonicalMIMEHeaderKey(v)
}
if !strutil.EquivalentSlices(configEntry.AllowedSTSHeaderValues, aHeadersValSl) {
// NOT setting changedCreds here, since this isn't really cached
configEntry.AllowedSTSHeaderValues = aHeadersValSl
changedOtherConfig = true
}
} else if req.Operation == logical.CreateOperation {
ah, ok := data.GetOk("allowed_sts_header_values")
if ok {
configEntry.AllowedSTSHeaderValues = ah.([]string)
}
}

maxRetriesInt, ok := data.GetOk("max_retries")
if ok {
configEntry.MaxRetries = maxRetriesInt.(int)
Expand Down Expand Up @@ -293,14 +322,27 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
// Struct to hold 'aws_access_key' and 'aws_secret_key' that are required to
// interact with the AWS EC2 API.
type clientConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
Endpoint string `json:"endpoint"`
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
MaxRetries int `json:"max_retries"`
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
Endpoint string `json:"endpoint"`
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"`
MaxRetries int `json:"max_retries"`
}

func (c *clientConfig) validateAllowedSTSHeaderValues(headers http.Header) error {
for k := range headers {
h := textproto.CanonicalMIMEHeaderKey(k)
if strings.HasPrefix(h, amzHeaderPrefix) &&
!strutil.StrListContains(defaultAllowedSTSRequestHeaders, h) &&
!strutil.StrListContains(c.AllowedSTSHeaderValues, h) {
return errors.New("invalid request header: " + k)
}
}
return nil
}

const pathConfigClientHelpSyn = `
Expand Down
51 changes: 49 additions & 2 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/pem"
"encoding/xml"
"errors"
"fmt"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -43,6 +44,11 @@ const (
retryWaitMax = 30 * time.Second
)

var (
errRequestBodyNotValid = errors.New("iam request body is invalid")
errInvalidGetCallerIdentityResponse = errors.New("body of GetCallerIdentity is invalid")
)

func (b *backend) pathLogin() *framework.Path {
return &framework.Path{
Pattern: "login$",
Expand Down Expand Up @@ -1179,7 +1185,10 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
if err != nil {
return logical.ErrorResponse("error parsing iam_request_url"), nil
}

if parsedUrl.RawQuery != "" {
// Should be no query parameters
return logical.ErrorResponse(logical.ErrInvalidRequest.Error()), nil
}
// TODO: There are two potentially valid cases we're not yet supporting that would
// necessitate this check being changed. First, if we support GET requests.
// Second if we support presigned POST requests
Expand All @@ -1192,6 +1201,9 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("failed to base64 decode iam_request_body"), nil
}
body := string(bodyRaw)
if err = validateLoginIamRequestBody(body); err != nil {
return logical.ErrorResponse(err.Error()), nil
}

headers := data.Get("iam_request_headers").(http.Header)
if len(headers) == 0 {
Expand All @@ -1213,6 +1225,9 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
return logical.ErrorResponse(fmt.Sprintf("error validating %s header: %v", iamServerIdHeader, err)), nil
}
}
if err = config.validateAllowedSTSHeaderValues(headers); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
if config.STSEndpoint != "" {
endpoint = config.STSEndpoint
}
Expand Down Expand Up @@ -1394,6 +1409,29 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
}, nil
}

// Validate that the iam_request_body passed is valid for the STS request
func validateLoginIamRequestBody(body string) error {
qs, err := url.ParseQuery(body)
if err != nil {
return err
}
for k, v := range qs {
switch k {
case "Action":
if len(v) != 1 || v[0] != "GetCallerIdentity" {
return errRequestBodyNotValid
}
case "Version":
// Will assume for now that future versions don't change
// the semantics
default:
// Not expecting any other values
return errRequestBodyNotValid
}
}
return nil
}

// These two methods (hasValuesFor*) return two bools
// The first is a hasAll, that is, does the request have all the values
// necessary for this auth method
Expand Down Expand Up @@ -1559,8 +1597,12 @@ func ensureHeaderIsSigned(signedHeaders, headerToSign string) error {
}

func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse, error) {
decoder := xml.NewDecoder(strings.NewReader(response))
result := GetCallerIdentityResponse{}
response = strings.TrimSpace(response)
if !strings.HasPrefix(response, "<GetCallerIdentityResponse") && !strings.HasPrefix(response, "<?xml") {
return result, errInvalidGetCallerIdentityResponse
}
decoder := xml.NewDecoder(strings.NewReader(response))
err := decoder.Decode(&result)
return result, err
}
Expand Down Expand Up @@ -1596,6 +1638,11 @@ func submitCallerIdentityRequest(ctx context.Context, maxRetries int, method, en
if response != nil {
defer response.Body.Close()
}
// Validate that the response type is XML
if ct := response.Header.Get("Content-Type"); ct != "text/xml" {
return nil, errInvalidGetCallerIdentityResponse
}

// we check for status code afterwards to also print out response body
responseBody, err := ioutil.ReadAll(response.Body)
if err != nil {
Expand Down
17 changes: 16 additions & 1 deletion builtin/credential/aws/path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,19 @@ func TestBackend_pathLogin_IAMHeaders(t *testing.T) {
},
ExpectErr: missingHeaderErr,
},
{
Name: "Map-illegal-header",
Header: map[string]interface{}{
"Content-Length": "43",
"Content-Type": "application/x-www-form-urlencoded; charset=utf-8",
"User-Agent": "aws-sdk-go/1.14.24 (go1.11; darwin; amd64)",
"X-Amz-Date": "20180910T203328Z",
"Authorization": "AWS4-HMAC-SHA256 Credential=AKIAJPQ466AIIQW4LPSQ/20180910/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-vault-aws-iam-server-id, Signature=cdef5819b2e97f1ff0f3e898fd2621aa03af00a4ec3e019122c20e5482534bf4",
"X-Vault-Aws-Iam-Server-Id": "VaultAcceptanceTesting",
"X-Amz-Mallory-Header": "<?xml><h4ck0r/>",
},
ExpectErr: errors.New("invalid request header: X-Amz-Mallory-Header"),
},
{
Name: "JSON-complete",
Header: `{
Expand Down Expand Up @@ -543,7 +556,8 @@ func setupIAMTestServer() *httptest.Server {
<ResponseMetadata>
<RequestId>7f4fc40c-853a-11e6-8848-8d035d01eb87</RequestId>
</ResponseMetadata>
</GetCallerIdentityResponse>`
</GetCallerIdentityResponse>
`

auth := r.Header.Get("Authorization")
parts := strings.Split(auth, ",")
Expand All @@ -566,6 +580,7 @@ func setupIAMTestServer() *httptest.Server {
if matchingCount != len(expectedAuthParts) {
responseString = "missing auth parts"
}
w.Header().Add("Content-Type", "text/xml")
fmt.Fprintln(w, responseString)
}))
}
Expand Down
6 changes: 5 additions & 1 deletion website/pages/api-docs/auth/aws/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ capabilities, the credentials are fetched automatically.
signed headers validated by AWS. This is to protect against different types of
replay attacks, for example a signed request sent to a dev server being resent
to a production server. Consider setting this to the Vault server's DNS name.

- `allowed_sts_header_values` `(string: "")` A comma separated list of
additional request headers permitted when providing the iam_request_headers for
an IAM based login call. In any case, a default list of headers AWS STS
expects for a GetCallerIdentity are allowed.

### Sample Payload

```json
Expand Down