Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws/signer/v4: Ensure logger option passed down to sign method #964

Merged
merged 2 commits into from
Dec 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions aws/signer/v4/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,18 @@ func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middl
return out, metadata, &SigningError{Err: fmt.Errorf("failed to retrieve credentials: %w", err)}
}

err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, signingRegion, sdk.NowTime(), s.addSignerOptions)
err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, signingRegion, sdk.NowTime(),
func(o *SignerOptions) {
o.Logger = middleware.GetLogger(ctx)
o.LogSigning = s.logSigning
})
if err != nil {
return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", err)}
}

return next.HandleFinalize(ctx, in)
}

func (s *SignHTTPRequestMiddleware) addSignerOptions(options *SignerOptions) {
options.LogSigning = s.logSigning
}

func haveCredentialProvider(p aws.CredentialsProvider) bool {
if p == nil {
return false
Expand Down
37 changes: 36 additions & 1 deletion aws/signer/v4/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ import (
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
"github.com/awslabs/smithy-go/logging"
"github.com/awslabs/smithy-go/middleware"
smithyhttp "github.com/awslabs/smithy-go/transport/http"
)
Expand Down Expand Up @@ -91,6 +93,7 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
cases := map[string]struct {
creds aws.CredentialsProvider
hash string
logSigning bool
expectedErr error
}{
"success": {
Expand All @@ -108,6 +111,11 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
"nil creds": {
creds: nil,
},
"with log signing": {
creds: unit.StubCredentialsProvider{},
hash: "0123456789abcdef",
logSigning: true,
},
}

const (
Expand All @@ -122,8 +130,20 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
signer: httpSignerFunc(
func(ctx context.Context,
credentials aws.Credentials, r *http.Request, payloadHash string,
service string, region string, signingTime time.Time, _ ...func(*SignerOptions),
service string, region string, signingTime time.Time,
optFns ...func(*SignerOptions),
) error {
var options SignerOptions
for _, fn := range optFns {
fn(&options)
}
if options.Logger == nil {
t.Errorf("expect logger, got none")
}
if options.LogSigning {
options.Logger.Logf(logging.Debug, t.Name())
}

expectCreds, _ := unit.StubCredentialsProvider{}.Retrieve(context.Background())
if e, a := expectCreds, credentials; e != a {
t.Errorf("expected %v, got %v", e, a)
Expand All @@ -139,6 +159,7 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
}
return nil
}),
logSigning: tt.logSigning,
}

next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
Expand All @@ -149,6 +170,10 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
awsmiddleware.SetSigningName(context.Background(), signingName),
signingRegion)

var loggerBuf bytes.Buffer
logger := logging.NewStandardLogger(&loggerBuf)
ctx = middleware.SetLogger(ctx, logger)

if len(tt.hash) != 0 {
ctx = context.WithValue(ctx, payloadHashKey{}, tt.hash)
}
Expand All @@ -166,6 +191,16 @@ func TestSignHTTPRequestMiddleware(t *testing.T) {
} else if err == nil && tt.expectedErr != nil {
t.Errorf("expected error, got nil")
}

if tt.logSigning {
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
t.Errorf("expect %v logged in %v", e, a)
}
} else {
if loggerBuf.Len() != 0 {
t.Errorf("expect no log, got %v", loggerBuf.String())
}
}
})
}
}
Expand Down
10 changes: 5 additions & 5 deletions aws/signer/v4/presign_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ func (s *PresignHTTPRequestMiddleware) HandleFinalize(
}

u, h, err := s.presigner.PresignHTTP(ctx, credentials,
httpReq, payloadHash, signingName, signingRegion, sdk.NowTime(), s.addSignerOptions)
httpReq, payloadHash, signingName, signingRegion, sdk.NowTime(),
func(o *SignerOptions) {
o.Logger = middleware.GetLogger(ctx)
o.LogSigning = s.logSigning
})
if err != nil {
return out, metadata, &SigningError{
Err: fmt.Errorf("failed to sign http request, %w", err),
Expand All @@ -121,7 +125,3 @@ func (s *PresignHTTPRequestMiddleware) HandleFinalize(

return out, metadata, nil
}

func (s *PresignHTTPRequestMiddleware) addSignerOptions(options *SignerOptions) {
options.LogSigning = s.logSigning
}
50 changes: 46 additions & 4 deletions aws/signer/v4/presign_middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package v4

import (
"bytes"
"context"
"net/http"
"net/url"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
"github.com/awslabs/smithy-go/logging"
"github.com/awslabs/smithy-go/middleware"
smithyhttp "github.com/awslabs/smithy-go/transport/http"
"github.com/google/go-cmp/cmp"
Expand All @@ -37,6 +39,7 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
Request *http.Request
Creds aws.CredentialsProvider
PayloadHash string
LogSigning bool
ExpectResult *PresignedHTTPRequest
ExpectErr string
}{
Expand All @@ -55,7 +58,6 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
SignedHeader: http.Header{},
},
},

"error": {
Request: func() *http.Request {
return &http.Request{}
Expand All @@ -64,7 +66,6 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
PayloadHash: "",
ExpectErr: "failed to sign request",
},

"anonymous creds": {
Request: &http.Request{
URL: func() *url.URL {
Expand All @@ -81,7 +82,6 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
SignedHeader: http.Header{},
},
},

"nil creds": {
Request: &http.Request{
URL: func() *url.URL {
Expand All @@ -96,6 +96,23 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
SignedHeader: http.Header{},
},
},
"with log signing": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: unit.StubCredentialsProvider{},
PayloadHash: "0123456789abcdef",
ExpectResult: &PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},

LogSigning: true,
},
}

const (
Expand All @@ -111,8 +128,18 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
presigner: httpPresignerFunc(func(
ctx context.Context, credentials aws.Credentials, r *http.Request,
payloadHash string, service string, region string, signingTime time.Time,
_ ...func(*SignerOptions),
optFns ...func(*SignerOptions),
) (url string, signedHeader http.Header, err error) {
var options SignerOptions
for _, fn := range optFns {
fn(&options)
}
if options.Logger == nil {
t.Errorf("expect logger, got none")
}
if options.LogSigning {
options.Logger.Logf(logging.Debug, t.Name())
}

if !haveCredentialProvider(c.Creds) {
t.Errorf("expect presigner not to be called for not credentials provider")
Expand All @@ -134,6 +161,7 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {

return c.ExpectResult.URL, c.ExpectResult.SignedHeader, nil
}),
logSigning: c.LogSigning,
}

next := middleware.FinalizeHandlerFunc(
Expand All @@ -148,6 +176,10 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
awsmiddleware.SetSigningName(context.Background(), signingName),
signingRegion)

var loggerBuf bytes.Buffer
logger := logging.NewStandardLogger(&loggerBuf)
ctx = middleware.SetLogger(ctx, logger)

if len(c.PayloadHash) != 0 {
ctx = context.WithValue(ctx, payloadHashKey{}, c.PayloadHash)
}
Expand All @@ -173,6 +205,16 @@ func TestPresignHTTPRequestMiddleware(t *testing.T) {
if diff := cmp.Diff(c.ExpectResult, result.Result); len(diff) != 0 {
t.Errorf("expect result match\n%v", diff)
}

if c.LogSigning {
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
t.Errorf("expect %v logged in %v", e, a)
}
} else {
if loggerBuf.Len() != 0 {
t.Errorf("expect no log, got %v", loggerBuf.String())
}
}
})
}
}
Expand Down