Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pull retry loop forward to cover everything from resolving auth scheme onward #2966

Merged
merged 3 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ public void writeAdditionalFiles(
if (supportsComputeInputChecksumsWorkflow) {
goDelegator.useShapeWriter(service, writer -> {
generateInputComputedChecksumMetadataHelpers(writer, model, symbolProvider, service);
writePackageLevelAddInputChecksumMiddleware(writer);
});
}

Expand Down Expand Up @@ -257,16 +258,14 @@ private void writeInputMiddlewareHelper(
writer.openBlock("func $L(stack *middleware.Stack, options Options) error {", "}",
getAddInputMiddlewareFuncName(operationName), () -> {
writer.write("""
return $T(stack, $T{
return addInputChecksumMiddleware(stack, $T{
GetAlgorithm: $L,
RequireChecksum: $L,
RequestChecksumCalculation: options.RequestChecksumCalculation,
EnableTrailingChecksum: $L,
EnableComputeSHA256PayloadHash: true,
EnableDecodedContentLengthHeader: $L,
})""",
SymbolUtils.createValueSymbolBuilder("AddInputMiddleware",
AwsGoDependency.SERVICE_INTERNAL_CHECKSUM).build(),
SymbolUtils.createValueSymbolBuilder("InputMiddlewareOptions",
AwsGoDependency.SERVICE_INTERNAL_CHECKSUM).build(),
hasRequestAlgorithmMember ?
Expand All @@ -279,6 +278,48 @@ private void writeInputMiddlewareHelper(
writer.insertTrailingNewline();
}

// adapted (service/internal/checksum).AddInputMiddleware to give the service client control over its middleware stack,
// per #2507
private void writePackageLevelAddInputChecksumMiddleware(GoWriter writer) {
writer.addUseImports(SmithyGoDependency.SMITHY_MIDDLEWARE);
writer.addUseImports(AwsGoDependency.SERVICE_INTERNAL_CHECKSUM);
writer.write("""
func addInputChecksumMiddleware(stack *middleware.Stack, options internalChecksum.InputMiddlewareOptions) (err error) {
err = stack.Initialize.Add(&internalChecksum.SetupInputContext{
GetAlgorithm: options.GetAlgorithm,
RequireChecksum: options.RequireChecksum,
RequestChecksumCalculation: options.RequestChecksumCalculation,
}, middleware.Before)
if err != nil {
return err
}

stack.Build.Remove("ContentChecksum")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we had this in place since before you made this change, but is it because ContentChecksum is the older way of doing checksums?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, ContentChecksum is the old MD5, which the recent changes to S3 object integrity effectively removes.


inputChecksum := &internalChecksum.ComputeInputPayloadChecksum{
EnableTrailingChecksum: options.EnableTrailingChecksum,
EnableComputePayloadHash: options.EnableComputeSHA256PayloadHash,
EnableDecodedContentLengthHeader: options.EnableDecodedContentLengthHeader,
}
if err := stack.Finalize.Insert(inputChecksum, "ResolveEndpointV2", middleware.After); err != nil {
return err
}

if options.EnableTrailingChecksum {
trailerMiddleware := &internalChecksum.AddInputChecksumTrailer{
EnableTrailingChecksum: inputChecksum.EnableTrailingChecksum,
EnableComputePayloadHash: inputChecksum.EnableComputePayloadHash,
EnableDecodedContentLengthHeader: inputChecksum.EnableDecodedContentLengthHeader,
}
if err := stack.Finalize.Insert(trailerMiddleware, inputChecksum.ID(), middleware.After); err != nil {
return err
}
}

return nil
}""");
}

private void writeOutputMiddlewareHelper(
GoWriter writer,
Model model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func addRetry(stack *middleware.Stack, o Options) error {
m.LogAttempts = o.ClientLogMode.IsRetries()
m.OperationMeter = o.MeterProvider.Meter($S)
})
if err := stack.Finalize.Insert(attempt, "Signing", middleware.Before); err != nil {
if err := stack.Finalize.Insert(attempt, "ResolveAuthScheme", middleware.Before); err != nil {
return err
}
if err := stack.Finalize.Insert(&retry.MetricsHeader{}, attempt.ID(), middleware.After); err != nil {
Expand Down
37 changes: 9 additions & 28 deletions service/internal/checksum/middleware_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,13 @@ type InputMiddlewareOptions struct {

// AddInputMiddleware adds the middleware for performing checksum computing
// of request payloads, and checksum validation of response payloads.
//
// Deprecated: This internal-only runtime API is frozen. Do not call or modify
// it in new code. Checksum-enabled service operations now generate this
// middleware setup code inline per #2507.
func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions) (err error) {
// TODO ensure this works correctly with presigned URLs
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(it does)


// Middleware stack:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is so beyond out-of-date at this point I've just removed it

// * (OK)(Initialize) --none--
// * (OK)(Serialize) EndpointResolver
// * (OK)(Build) ComputeContentLength
// * (AD)(Build) Header ComputeInputPayloadChecksum
// * SIGNED Payload - If HTTP && not support trailing checksum
// * UNSIGNED Payload - If HTTPS && not support trailing checksum
// * (RM)(Build) ContentChecksum - OK to remove
// * (OK)(Build) ComputePayloadHash
// * v4.dynamicPayloadSigningMiddleware
// * v4.computePayloadSHA256
// * v4.unsignedPayload
// (OK)(Build) Set computedPayloadHash header
// * (OK)(Finalize) Retry
// * (AD)(Finalize) Trailer ComputeInputPayloadChecksum,
// * Requires HTTPS && support trailing checksum
// * UNSIGNED Payload
// * Finalize run if HTTPS && support trailing checksum
// * (OK)(Finalize) Signing
// * (OK)(Deserialize) --none--

// Initial checksum configuration look up middleware
err = stack.Initialize.Add(&setupInputContext{
err = stack.Initialize.Add(&SetupInputContext{
GetAlgorithm: options.GetAlgorithm,
RequireChecksum: options.RequireChecksum,
RequestChecksumCalculation: options.RequestChecksumCalculation,
Expand All @@ -88,7 +69,7 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions)

stack.Build.Remove("ContentChecksum")

inputChecksum := &computeInputPayloadChecksum{
inputChecksum := &ComputeInputPayloadChecksum{
EnableTrailingChecksum: options.EnableTrailingChecksum,
EnableComputePayloadHash: options.EnableComputeSHA256PayloadHash,
EnableDecodedContentLengthHeader: options.EnableDecodedContentLengthHeader,
Expand All @@ -99,7 +80,7 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions)

// If trailing checksum is not supported no need for finalize handler to be added.
if options.EnableTrailingChecksum {
trailerMiddleware := &addInputChecksumTrailer{
trailerMiddleware := &AddInputChecksumTrailer{
EnableTrailingChecksum: inputChecksum.EnableTrailingChecksum,
EnableComputePayloadHash: inputChecksum.EnableComputePayloadHash,
EnableDecodedContentLengthHeader: inputChecksum.EnableDecodedContentLengthHeader,
Expand All @@ -115,10 +96,10 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions)
// RemoveInputMiddleware Removes the compute input payload checksum middleware
// handlers from the stack.
func RemoveInputMiddleware(stack *middleware.Stack) {
id := (*setupInputContext)(nil).ID()
id := (*SetupInputContext)(nil).ID()
stack.Initialize.Remove(id)

id = (*computeInputPayloadChecksum)(nil).ID()
id = (*ComputeInputPayloadChecksum)(nil).ID()
stack.Finalize.Remove(id)
}

Expand Down
32 changes: 16 additions & 16 deletions service/internal/checksum/middleware_add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ func TestAddInputMiddleware(t *testing.T) {
options InputMiddlewareOptions
expectErr string
expectMiddleware []string
expectInitialize *setupInputContext
expectFinalize *computeInputPayloadChecksum
expectInitialize *SetupInputContext
expectFinalize *ComputeInputPayloadChecksum
}{
"with trailing checksum": {
options: InputMiddlewareOptions{
Expand All @@ -46,12 +46,12 @@ func TestAddInputMiddleware(t *testing.T) {
"Signing",
"Deserialize stack step",
},
expectInitialize: &setupInputContext{
expectInitialize: &SetupInputContext{
GetAlgorithm: func(interface{}) (string, bool) {
return string(AlgorithmCRC32), true
},
},
expectFinalize: &computeInputPayloadChecksum{
expectFinalize: &ComputeInputPayloadChecksum{
EnableTrailingChecksum: true,
EnableComputePayloadHash: true,
EnableDecodedContentLengthHeader: true,
Expand Down Expand Up @@ -81,12 +81,12 @@ func TestAddInputMiddleware(t *testing.T) {
"Signing",
"Deserialize stack step",
},
expectInitialize: &setupInputContext{
expectInitialize: &SetupInputContext{
GetAlgorithm: func(interface{}) (string, bool) {
return string(AlgorithmCRC32), true
},
},
expectFinalize: &computeInputPayloadChecksum{
expectFinalize: &ComputeInputPayloadChecksum{
EnableTrailingChecksum: true,
},
},
Expand All @@ -111,12 +111,12 @@ func TestAddInputMiddleware(t *testing.T) {
"Signing",
"Deserialize stack step",
},
expectInitialize: &setupInputContext{
expectInitialize: &SetupInputContext{
GetAlgorithm: func(interface{}) (string, bool) {
return string(AlgorithmCRC32), true
},
},
expectFinalize: &computeInputPayloadChecksum{},
expectFinalize: &ComputeInputPayloadChecksum{},
},
}

Expand All @@ -140,12 +140,12 @@ func TestAddInputMiddleware(t *testing.T) {
t.Fatalf("expect stack list match:\n%s", diff)
}

initializeMiddleware, ok := stack.Initialize.Get((*setupInputContext)(nil).ID())
initializeMiddleware, ok := stack.Initialize.Get((*SetupInputContext)(nil).ID())
if e, a := (c.expectInitialize != nil), ok; e != a {
t.Errorf("expect initialize middleware %t, got %t", e, a)
}
if c.expectInitialize != nil && ok {
setupInput := initializeMiddleware.(*setupInputContext)
setupInput := initializeMiddleware.(*SetupInputContext)
if e, a := c.options.GetAlgorithm != nil, setupInput.GetAlgorithm != nil; e != a {
t.Fatalf("expect GetAlgorithm %t, got %t", e, a)
}
Expand All @@ -159,20 +159,20 @@ func TestAddInputMiddleware(t *testing.T) {
}
}

finalizeMW, ok := stack.Finalize.Get((*computeInputPayloadChecksum)(nil).ID())
finalizeMW, ok := stack.Finalize.Get((*ComputeInputPayloadChecksum)(nil).ID())
if e, a := (c.expectFinalize != nil), ok; e != a {
t.Errorf("expect build middleware %t, got %t", e, a)
}
var computeInput *computeInputPayloadChecksum
var ComputeInput *ComputeInputPayloadChecksum
if c.expectFinalize != nil && ok {
computeInput = finalizeMW.(*computeInputPayloadChecksum)
if e, a := c.expectFinalize.EnableTrailingChecksum, computeInput.EnableTrailingChecksum; e != a {
ComputeInput = finalizeMW.(*ComputeInputPayloadChecksum)
if e, a := c.expectFinalize.EnableTrailingChecksum, ComputeInput.EnableTrailingChecksum; e != a {
t.Errorf("expect %v enable trailing checksum, got %v", e, a)
}
if e, a := c.expectFinalize.EnableComputePayloadHash, computeInput.EnableComputePayloadHash; e != a {
if e, a := c.expectFinalize.EnableComputePayloadHash, ComputeInput.EnableComputePayloadHash; e != a {
t.Errorf("expect %v enable compute payload hash, got %v", e, a)
}
if e, a := c.expectFinalize.EnableDecodedContentLengthHeader, computeInput.EnableDecodedContentLengthHeader; e != a {
if e, a := c.expectFinalize.EnableDecodedContentLengthHeader, ComputeInput.EnableDecodedContentLengthHeader; e != a {
t.Errorf("expect %v enable decoded length header, got %v", e, a)
}
}
Expand Down
18 changes: 9 additions & 9 deletions service/internal/checksum/middleware_compute_input_checksum.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ func SetComputedInputChecksums(m *middleware.Metadata, vs map[string]string) {
m.Set(computedInputChecksumsKey{}, vs)
}

// computeInputPayloadChecksum middleware computes payload checksum
type computeInputPayloadChecksum struct {
// ComputeInputPayloadChecksum middleware computes payload checksum
type ComputeInputPayloadChecksum struct {
// Enables support for wrapping the serialized input payload with a
// content-encoding: aws-check wrapper, and including a trailer for the
// algorithm's checksum value.
Expand Down Expand Up @@ -71,7 +71,7 @@ type computeInputPayloadChecksum struct {
type useTrailer struct{}

// ID provides the middleware's identifier.
func (m *computeInputPayloadChecksum) ID() string {
func (m *ComputeInputPayloadChecksum) ID() string {
return "AWSChecksum:ComputeInputPayloadChecksum"
}

Expand All @@ -91,14 +91,14 @@ func (e computeInputHeaderChecksumError) Error() string {
}
func (e computeInputHeaderChecksumError) Unwrap() error { return e.Err }

// HandleBuild handles computing the payload's checksum, in the following cases:
// HandleFinalize handles computing the payload's checksum, in the following cases:
// - Is HTTP, not HTTPS
// - RequireChecksum is true, and no checksums were specified via the Input
// - Trailing checksums are not supported
//
// The build handler must be inserted in the stack before ContentPayloadHash
// and after ComputeContentLength.
func (m *computeInputPayloadChecksum) HandleFinalize(
func (m *ComputeInputPayloadChecksum) HandleFinalize(
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
Expand Down Expand Up @@ -228,23 +228,23 @@ func (e computeInputTrailingChecksumError) Error() string {
}
func (e computeInputTrailingChecksumError) Unwrap() error { return e.Err }

// addInputChecksumTrailer
// AddInputChecksumTrailer adds HTTP checksum when
// - Is HTTPS, not HTTP
// - A checksum was specified via the Input
// - Trailing checksums are supported.
type addInputChecksumTrailer struct {
type AddInputChecksumTrailer struct {
EnableTrailingChecksum bool
EnableComputePayloadHash bool
EnableDecodedContentLengthHeader bool
}

// ID identifies this middleware.
func (*addInputChecksumTrailer) ID() string {
func (*AddInputChecksumTrailer) ID() string {
return "addInputChecksumTrailer"
}

// HandleFinalize wraps the request body to write the trailing checksum.
func (m *addInputChecksumTrailer) HandleFinalize(
func (m *AddInputChecksumTrailer) HandleFinalize(
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (

func TestComputeInputPayloadChecksum(t *testing.T) {
cases := map[string]map[string]struct {
optionsFn func(*computeInputPayloadChecksum)
optionsFn func(*ComputeInputPayloadChecksum)
initContext func(context.Context) context.Context
buildInput middleware.BuildInput

Expand Down Expand Up @@ -392,7 +392,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) {
initContext: func(ctx context.Context) context.Context {
return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32))
},
optionsFn: func(o *computeInputPayloadChecksum) {
optionsFn: func(o *ComputeInputPayloadChecksum) {
o.EnableComputePayloadHash = false
},
buildInput: middleware.BuildInput{
Expand All @@ -414,7 +414,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) {
},
},
"https no trailing checksum": {
optionsFn: func(o *computeInputPayloadChecksum) {
optionsFn: func(o *ComputeInputPayloadChecksum) {
o.EnableTrailingChecksum = false
},
initContext: func(ctx context.Context) context.Context {
Expand All @@ -439,7 +439,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) {
},
},
"with content encoding set": {
optionsFn: func(o *computeInputPayloadChecksum) {
optionsFn: func(o *ComputeInputPayloadChecksum) {
o.EnableTrailingChecksum = false
},
initContext: func(ctx context.Context) context.Context {
Expand Down Expand Up @@ -550,7 +550,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) {
expectBuildErr: true,
},
"https no trailing unseekable stream": {
optionsFn: func(o *computeInputPayloadChecksum) {
optionsFn: func(o *ComputeInputPayloadChecksum) {
o.EnableTrailingChecksum = false
},
initContext: func(ctx context.Context) context.Context {
Expand Down Expand Up @@ -677,7 +677,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) {
initContext: func(ctx context.Context) context.Context {
return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32))
},
optionsFn: func(o *computeInputPayloadChecksum) {
optionsFn: func(o *ComputeInputPayloadChecksum) {
o.EnableComputePayloadHash = false
},
buildInput: middleware.BuildInput{
Expand All @@ -702,7 +702,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) {
},
},
"https no decode content length": {
optionsFn: func(o *computeInputPayloadChecksum) {
optionsFn: func(o *ComputeInputPayloadChecksum) {
o.EnableDecodedContentLengthHeader = false
},
initContext: func(ctx context.Context) context.Context {
Expand Down Expand Up @@ -763,7 +763,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) {
t.Run(name, func(t *testing.T) {
for name, c := range cs {
t.Run(name, func(t *testing.T) {
m := &computeInputPayloadChecksum{
m := &ComputeInputPayloadChecksum{
EnableTrailingChecksum: true,
EnableComputePayloadHash: true,
EnableDecodedContentLengthHeader: true,
Expand All @@ -772,7 +772,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) {
if c.optionsFn != nil {
c.optionsFn(m)
}
trailerMiddleware := &addInputChecksumTrailer{
trailerMiddleware := &AddInputChecksumTrailer{
EnableTrailingChecksum: m.EnableTrailingChecksum,
EnableComputePayloadHash: m.EnableComputePayloadHash,
EnableDecodedContentLengthHeader: m.EnableDecodedContentLengthHeader,
Expand Down
Loading