Skip to content

Commit

Permalink
aws/signer/v4: Ensure logger option passed down to sign method (#964)
Browse files Browse the repository at this point in the history
Ensures the middleware's Logger option value is passed down to the signer's sign method via functional options.
  • Loading branch information
jasdel authored Dec 12, 2020
1 parent 1478538 commit 3fb2b0c
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 15 deletions.
10 changes: 5 additions & 5 deletions aws/signer/v4/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,18 @@ func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middl
return out, metadata, &SigningError{Err: fmt.Errorf("failed to retrieve credentials: %w", err)}
}

err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, signingRegion, sdk.NowTime(), s.addSignerOptions)
err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, signingRegion, sdk.NowTime(),
func(o *SignerOptions) {
o.Logger = middleware.GetLogger(ctx)
o.LogSigning = s.logSigning
})
if err != nil {
return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", err)}
}

return next.HandleFinalize(ctx, in)
}

func (s *SignHTTPRequestMiddleware) addSignerOptions(options *SignerOptions) {
options.LogSigning = s.logSigning
}

func haveCredentialProvider(p aws.CredentialsProvider) bool {
if p == nil {
return false
Expand Down
37 changes: 36 additions & 1 deletion aws/signer/v4/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ import (
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
"github.com/awslabs/smithy-go/logging"
"github.com/awslabs/smithy-go/middleware"
smithyhttp "github.com/awslabs/smithy-go/transport/http"
)
Expand Down Expand Up @@ -91,6 +93,7 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
cases := map[string]struct {
creds aws.CredentialsProvider
hash string
logSigning bool
expectedErr error
}{
"success": {
Expand All @@ -108,6 +111,11 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
"nil creds": {
creds: nil,
},
"with log signing": {
creds: unit.StubCredentialsProvider{},
hash: "0123456789abcdef",
logSigning: true,
},
}

const (
Expand All @@ -122,8 +130,20 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
signer: httpSignerFunc(
func(ctx context.Context,
credentials aws.Credentials, r *http.Request, payloadHash string,
service string, region string, signingTime time.Time, _ ...func(*SignerOptions),
service string, region string, signingTime time.Time,
optFns ...func(*SignerOptions),
) error {
var options SignerOptions
for _, fn := range optFns {
fn(&options)
}
if options.Logger == nil {
t.Errorf("expect logger, got none")
}
if options.LogSigning {
options.Logger.Logf(logging.Debug, t.Name())
}

expectCreds, _ := unit.StubCredentialsProvider{}.Retrieve(context.Background())
if e, a := expectCreds, credentials; e != a {
t.Errorf("expected %v, got %v", e, a)
Expand All @@ -139,6 +159,7 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
}
return nil
}),
logSigning: tt.logSigning,
}

next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
Expand All @@ -149,6 +170,10 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
awsmiddleware.SetSigningName(context.Background(), signingName),
signingRegion)

var loggerBuf bytes.Buffer
logger := logging.NewStandardLogger(&loggerBuf)
ctx = middleware.SetLogger(ctx, logger)

if len(tt.hash) != 0 {
ctx = context.WithValue(ctx, payloadHashKey{}, tt.hash)
}
Expand All @@ -166,6 +191,16 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
} else if err == nil && tt.expectedErr != nil {
t.Errorf("expected error, got nil")
}

if tt.logSigning {
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
t.Errorf("expect %v logged in %v", e, a)
}
} else {
if loggerBuf.Len() != 0 {
t.Errorf("expect no log, got %v", loggerBuf.String())
}
}
})
}
}
Expand Down
10 changes: 5 additions & 5 deletions aws/signer/v4/presign_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ func (s *PresignHTTPRequestMiddleware) HandleFinalize(
}

u, h, err := s.presigner.PresignHTTP(ctx, credentials,
httpReq, payloadHash, signingName, signingRegion, sdk.NowTime(), s.addSignerOptions)
httpReq, payloadHash, signingName, signingRegion, sdk.NowTime(),
func(o *SignerOptions) {
o.Logger = middleware.GetLogger(ctx)
o.LogSigning = s.logSigning
})
if err != nil {
return out, metadata, &SigningError{
Err: fmt.Errorf("failed to sign http request, %w", err),
Expand All @@ -121,7 +125,3 @@ func (s *PresignHTTPRequestMiddleware) HandleFinalize(

return out, metadata, nil
}

func (s *PresignHTTPRequestMiddleware) addSignerOptions(options *SignerOptions) {
options.LogSigning = s.logSigning
}
50 changes: 46 additions & 4 deletions aws/signer/v4/presign_middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package v4

import (
"bytes"
"context"
"net/http"
"net/url"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
"github.com/awslabs/smithy-go/logging"
"github.com/awslabs/smithy-go/middleware"
smithyhttp "github.com/awslabs/smithy-go/transport/http"
"github.com/google/go-cmp/cmp"
Expand All @@ -37,6 +39,7 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
Request *http.Request
Creds aws.CredentialsProvider
PayloadHash string
LogSigning bool
ExpectResult *PresignedHTTPRequest
ExpectErr string
}{
Expand All @@ -55,7 +58,6 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
SignedHeader: http.Header{},
},
},

"error": {
Request: func() *http.Request {
return &http.Request{}
Expand All @@ -64,7 +66,6 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
PayloadHash: "",
ExpectErr: "failed to sign request",
},

"anonymous creds": {
Request: &http.Request{
URL: func() *url.URL {
Expand All @@ -81,7 +82,6 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
SignedHeader: http.Header{},
},
},

"nil creds": {
Request: &http.Request{
URL: func() *url.URL {
Expand All @@ -96,6 +96,23 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
SignedHeader: http.Header{},
},
},
"with log signing": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: unit.StubCredentialsProvider{},
PayloadHash: "0123456789abcdef",
ExpectResult: &PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},

LogSigning: true,
},
}

const (
Expand All @@ -111,8 +128,18 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
presigner: httpPresignerFunc(func(
ctx context.Context, credentials aws.Credentials, r *http.Request,
payloadHash string, service string, region string, signingTime time.Time,
_ ...func(*SignerOptions),
optFns ...func(*SignerOptions),
) (url string, signedHeader http.Header, err error) {
var options SignerOptions
for _, fn := range optFns {
fn(&options)
}
if options.Logger == nil {
t.Errorf("expect logger, got none")
}
if options.LogSigning {
options.Logger.Logf(logging.Debug, t.Name())
}

if !haveCredentialProvider(c.Creds) {
t.Errorf("expect presigner not to be called for not credentials provider")
Expand All @@ -134,6 +161,7 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {

return c.ExpectResult.URL, c.ExpectResult.SignedHeader, nil
}),
logSigning: c.LogSigning,
}

next := middleware.FinalizeHandlerFunc(
Expand All @@ -148,6 +176,10 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
awsmiddleware.SetSigningName(context.Background(), signingName),
signingRegion)

var loggerBuf bytes.Buffer
logger := logging.NewStandardLogger(&loggerBuf)
ctx = middleware.SetLogger(ctx, logger)

if len(c.PayloadHash) != 0 {
ctx = context.WithValue(ctx, payloadHashKey{}, c.PayloadHash)
}
Expand All @@ -173,6 +205,16 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
if diff := cmp.Diff(c.ExpectResult, result.Result); len(diff) != 0 {
t.Errorf("expect result match\n%v", diff)
}

if c.LogSigning {
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
t.Errorf("expect %v logged in %v", e, a)
}
} else {
if loggerBuf.Len() != 0 {
t.Errorf("expect no log, got %v", loggerBuf.String())
}
}
})
}
}
Expand Down

0 comments on commit 3fb2b0c

Please sign in to comment.