diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..49dd32ca261 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -1,5 +1,7 @@ ### SDK Features ### SDK Enhancements +* `service/neptune`: Support for PreSignedUrl generation for CopyDBClusterSnapshot and CreateDBCluster operations. ([#3782](https://github.com/aws/aws-sdk-go/pull/3782)) +* `service/docdb`: Support for PreSignedUrl generation for CopyDBClusterSnapshot and CreateDBCluster operations. ([#3782](https://github.com/aws/aws-sdk-go/pull/3782)) ### SDK Bugs diff --git a/private/model/api/customization_passes.go b/private/model/api/customization_passes.go index b9cc231ecef..a1c797b7884 100644 --- a/private/model/api/customization_passes.go +++ b/private/model/api/customization_passes.go @@ -49,6 +49,8 @@ func (a *API) customizationPasses() error { "s3control": s3ControlCustomizations, "cloudfront": cloudfrontCustomizations, "rds": rdsCustomizations, + "neptune": neptuneCustomizations, + "docdb": docdbCustomizations, // Disable endpoint resolving for services that require customer // to provide endpoint them selves. @@ -327,7 +329,34 @@ func rdsCustomizations(a *API) error { "CreateDBClusterInput", "StartDBInstanceAutomatedBackupsReplicationInput", } - for _, input := range inputs { + generatePresignedURL(a, inputs) + return nil +} + +// neptuneCustomizations are customization for the service/neptune. This adds +// non-modeled fields used for presigning. +func neptuneCustomizations(a *API) error { + inputs := []string{ + "CopyDBClusterSnapshotInput", + "CreateDBClusterInput", + } + generatePresignedURL(a, inputs) + return nil +} + +// neptuneCustomizations are customization for the service/neptune. This adds +// non-modeled fields used for presigning. +func docdbCustomizations(a *API) error { + inputs := []string{ + "CopyDBClusterSnapshotInput", + "CreateDBClusterInput", + } + generatePresignedURL(a, inputs) + return nil +} + +func generatePresignedURL(a *API, inputShapes []string) { + for _, input := range inputShapes { if ref, ok := a.Shapes[input]; ok { ref.MemberRefs["SourceRegion"] = &ShapeRef{ Documentation: docstring(`SourceRegion is the source region where the resource exists. This is not sent over the wire and is only used for presigning. This value should always have the same region as the source ARN.`), @@ -342,8 +371,6 @@ func rdsCustomizations(a *API) error { } } } - - return nil } func disableEndpointResolving(a *API) error { diff --git a/service/docdb/api.go b/service/docdb/api.go index 7035fe95dd1..bacc69d58c7 100644 --- a/service/docdb/api.go +++ b/service/docdb/api.go @@ -4964,6 +4964,9 @@ type CopyDBClusterSnapshotInput struct { // cluster snapshot, and otherwise false. The default is false. CopyTags *bool `type:"boolean"` + // DestinationRegion is used for presigning the request to a given region. + DestinationRegion *string `type:"string"` + // The AWS KMS key ID for an encrypted cluster snapshot. The AWS KMS key ID // is the Amazon Resource Name (ARN), AWS KMS key identifier, or the AWS KMS // key alias for the AWS KMS encryption key. @@ -5033,6 +5036,11 @@ type CopyDBClusterSnapshotInput struct { // SourceDBClusterSnapshotIdentifier is a required field SourceDBClusterSnapshotIdentifier *string `type:"string" required:"true"` + // SourceRegion is the source region where the resource exists. This is not + // sent over the wire and is only used for presigning. This value should always + // have the same region as the source ARN. + SourceRegion *string `type:"string" ignore:"true"` + // The tags to be assigned to the cluster snapshot. Tags []*Tag `locationNameList:"Tag" type:"list"` @@ -5085,6 +5093,12 @@ func (s *CopyDBClusterSnapshotInput) SetCopyTags(v bool) *CopyDBClusterSnapshotI return s } +// SetDestinationRegion sets the DestinationRegion field's value. +func (s *CopyDBClusterSnapshotInput) SetDestinationRegion(v string) *CopyDBClusterSnapshotInput { + s.DestinationRegion = &v + return s +} + // SetKmsKeyId sets the KmsKeyId field's value. func (s *CopyDBClusterSnapshotInput) SetKmsKeyId(v string) *CopyDBClusterSnapshotInput { s.KmsKeyId = &v @@ -5103,6 +5117,12 @@ func (s *CopyDBClusterSnapshotInput) SetSourceDBClusterSnapshotIdentifier(v stri return s } +// SetSourceRegion sets the SourceRegion field's value. +func (s *CopyDBClusterSnapshotInput) SetSourceRegion(v string) *CopyDBClusterSnapshotInput { + s.SourceRegion = &v + return s +} + // SetTags sets the Tags field's value. func (s *CopyDBClusterSnapshotInput) SetTags(v []*Tag) *CopyDBClusterSnapshotInput { s.Tags = v @@ -5188,6 +5208,9 @@ type CreateDBClusterInput struct { // deleted. DeletionProtection *bool `type:"boolean"` + // DestinationRegion is used for presigning the request to a given region. + DestinationRegion *string `type:"string"` + // A list of log types that need to be enabled for exporting to Amazon CloudWatch // Logs. You can enable audit logs or profiler logs. For more information, see // Auditing Amazon DocumentDB Events (https://docs.aws.amazon.com/documentdb/latest/developerguide/event-auditing.html) @@ -5282,6 +5305,11 @@ type CreateDBClusterInput struct { // Constraints: Minimum 30-minute window. PreferredMaintenanceWindow *string `type:"string"` + // SourceRegion is the source region where the resource exists. This is not + // sent over the wire and is only used for presigning. This value should always + // have the same region as the source ARN. + SourceRegion *string `type:"string" ignore:"true"` + // Specifies whether the cluster is encrypted. StorageEncrypted *bool `type:"boolean"` @@ -5360,6 +5388,12 @@ func (s *CreateDBClusterInput) SetDeletionProtection(v bool) *CreateDBClusterInp return s } +// SetDestinationRegion sets the DestinationRegion field's value. +func (s *CreateDBClusterInput) SetDestinationRegion(v string) *CreateDBClusterInput { + s.DestinationRegion = &v + return s +} + // SetEnableCloudwatchLogsExports sets the EnableCloudwatchLogsExports field's value. func (s *CreateDBClusterInput) SetEnableCloudwatchLogsExports(v []*string) *CreateDBClusterInput { s.EnableCloudwatchLogsExports = v @@ -5420,6 +5454,12 @@ func (s *CreateDBClusterInput) SetPreferredMaintenanceWindow(v string) *CreateDB return s } +// SetSourceRegion sets the SourceRegion field's value. +func (s *CreateDBClusterInput) SetSourceRegion(v string) *CreateDBClusterInput { + s.SourceRegion = &v + return s +} + // SetStorageEncrypted sets the StorageEncrypted field's value. func (s *CreateDBClusterInput) SetStorageEncrypted(v bool) *CreateDBClusterInput { s.StorageEncrypted = &v diff --git a/service/docdb/customizations.go b/service/docdb/customizations.go new file mode 100644 index 00000000000..f8fb85edf23 --- /dev/null +++ b/service/docdb/customizations.go @@ -0,0 +1,107 @@ +package docdb + +import ( + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awsutil" + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/request" +) + +func init() { + ops := []string{ + opCopyDBClusterSnapshot, + opCreateDBCluster, + } + initRequest = func(r *request.Request) { + for _, operation := range ops { + if r.Operation.Name == operation { + r.Handlers.Build.PushFront(fillPresignedURL) + } + } + } +} + +func fillPresignedURL(r *request.Request) { + fns := map[string]func(r *request.Request){ + opCopyDBClusterSnapshot: copyDBClusterSnapshotPresign, + opCreateDBCluster: createDBClusterPresign, + } + if !r.ParamsFilled() { + return + } + if f, ok := fns[r.Operation.Name]; ok { + f(r) + } +} + +func copyDBClusterSnapshotPresign(r *request.Request) { + originParams := r.Params.(*CopyDBClusterSnapshotInput) + + if originParams.SourceRegion == nil || originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { + return + } + + originParams.DestinationRegion = r.Config.Region + // preSignedUrl is not required for instances in the same region. + if *originParams.SourceRegion == *originParams.DestinationRegion { + return + } + + newParams := awsutil.CopyOf(r.Params).(*CopyDBClusterSnapshotInput) + originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) +} + +func createDBClusterPresign(r *request.Request) { + originParams := r.Params.(*CreateDBClusterInput) + + if originParams.SourceRegion == nil || originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { + return + } + + originParams.DestinationRegion = r.Config.Region + // preSignedUrl is not required for instances in the same region. + if *originParams.SourceRegion == *originParams.DestinationRegion { + return + } + + newParams := awsutil.CopyOf(r.Params).(*CreateDBClusterInput) + originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) +} + +// presignURL will presign the request by using SoureRegion to sign with. SourceRegion is not +// sent to the service, and is only used to not have the SDKs parsing ARNs. +func presignURL(r *request.Request, sourceRegion *string, newParams interface{}) *string { + cfg := r.Config.Copy(aws.NewConfig(). + WithEndpoint(""). + WithRegion(aws.StringValue(sourceRegion))) + + clientInfo := r.ClientInfo + resolved, err := r.Config.EndpointResolver.EndpointFor( + EndpointsID, aws.StringValue(cfg.Region), + func(opt *endpoints.Options) { + opt.DisableSSL = aws.BoolValue(cfg.DisableSSL) + opt.UseDualStack = aws.BoolValue(cfg.UseDualStack) + }, + ) + if err != nil { + r.Error = err + return nil + } + + clientInfo.Endpoint = resolved.URL + clientInfo.SigningRegion = resolved.SigningRegion + + // Presign a request with modified params + req := request.New(*cfg, clientInfo, r.Handlers, r.Retryer, r.Operation, newParams, r.Data) + req.Operation.HTTPMethod = "GET" + uri, err := req.Presign(5 * time.Minute) // 5 minutes should be enough. + if err != nil { // bubble error back up to original request + r.Error = err + return nil + } + + // We have our URL, set it on params + return &uri +} diff --git a/service/docdb/customizations_test.go b/service/docdb/customizations_test.go new file mode 100644 index 00000000000..0120548158d --- /dev/null +++ b/service/docdb/customizations_test.go @@ -0,0 +1,187 @@ +// +build go1.9 + +package docdb + +import ( + "fmt" + "io/ioutil" + "net/url" + "regexp" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/awstesting" + "github.com/aws/aws-sdk-go/awstesting/unit" +) + +func TestCopyDBClusterSnapshotRequestNoPanic(t *testing.T) { + svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")}) + + f := func() { + // Doesn't panic on nil input + req, _ := svc.CopyDBClusterSnapshotRequest(nil) + req.Sign() + } + if paniced, p := awstesting.DidPanic(f); paniced { + t.Errorf("expect no panic, got %v", p) + } +} + +func TestPresignCrossRegionRequest(t *testing.T) { + const targetRegion = "us-west-2" + + svc := New(unit.Session, &aws.Config{Region: aws.String(targetRegion)}) + + const regexPattern = `^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=%s.+` + + cases := map[string]struct { + Req *request.Request + Assert func(*testing.T, string) + }{ + opCopyDBClusterSnapshot: { + Req: func() *request.Request { + req, _ := svc.CopyDBClusterSnapshotRequest( + &CopyDBClusterSnapshotInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + }) + return req + }(), + Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern, + opCopyDBClusterSnapshot, targetRegion)), + }, + opCreateDBCluster: { + Req: func() *request.Request { + req, _ := svc.CreateDBClusterRequest( + &CreateDBClusterInput{ + SourceRegion: aws.String("us-west-1"), + DBClusterIdentifier: aws.String("foo"), + Engine: aws.String("bar"), + MasterUsername: aws.String("user"), + MasterUserPassword: aws.String("password"), + }) + return req + }(), + Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern, + opCreateDBCluster, targetRegion)), + }, + opCopyDBClusterSnapshot + " same region": { + Req: func() *request.Request { + req, _ := svc.CopyDBClusterSnapshotRequest( + &CopyDBClusterSnapshotInput{ + SourceRegion: aws.String("us-west-2"), + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + }) + return req + }(), + Assert: assertAsEmpty(), + }, + opCreateDBCluster + " same region": { + Req: func() *request.Request { + req, _ := svc.CreateDBClusterRequest( + &CreateDBClusterInput{ + SourceRegion: aws.String("us-west-2"), + DBClusterIdentifier: aws.String("foo"), + Engine: aws.String("bar"), + MasterUsername: aws.String("user"), + MasterUserPassword: aws.String("password"), + }) + return req + }(), + Assert: assertAsEmpty(), + }, + opCopyDBClusterSnapshot + " presignURL set": { + Req: func() *request.Request { + req, _ := svc.CopyDBClusterSnapshotRequest( + &CopyDBClusterSnapshotInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + PreSignedUrl: aws.String("mockPresignedURL"), + }) + return req + }(), + Assert: assertAsEqual("mockPresignedURL"), + }, + opCreateDBCluster + " presignURL set": { + Req: func() *request.Request { + req, _ := svc.CreateDBClusterRequest( + &CreateDBClusterInput{ + SourceRegion: aws.String("us-west-1"), + DBClusterIdentifier: aws.String("foo"), + Engine: aws.String("bar"), + PreSignedUrl: aws.String("mockPresignedURL"), + MasterUsername: aws.String("user"), + MasterUserPassword: aws.String("password"), + }) + return req + }(), + Assert: assertAsEqual("mockPresignedURL"), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + if err := c.Req.Sign(); err != nil { + t.Fatalf("expect no error, got %v", err) + } + b, _ := ioutil.ReadAll(c.Req.HTTPRequest.Body) + q, _ := url.ParseQuery(string(b)) + + u, _ := url.QueryUnescape(q.Get("PreSignedUrl")) + + c.Assert(t, u) + }) + } +} + +func TestPresignWithSourceNotSet(t *testing.T) { + reqs := map[string]*request.Request{} + svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")}) + + reqs[opCopyDBClusterSnapshot], _ = svc.CopyDBClusterSnapshotRequest(&CopyDBClusterSnapshotInput{ + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + }) + + for _, req := range reqs { + _, err := req.Presign(5 * time.Minute) + if err != nil { + t.Fatal(err) + } + } +} + +func assertAsRegexMatch(exp string) func(*testing.T, string) { + return func(t *testing.T, v string) { + t.Helper() + + if re, a := regexp.MustCompile(exp), v; !re.MatchString(a) { + t.Errorf("expect %s to match %s", re, a) + } + } +} + +func assertAsEmpty() func(*testing.T, string) { + return func(t *testing.T, v string) { + t.Helper() + + if len(v) != 0 { + t.Errorf("expect empty, got %v", v) + } + } +} + +func assertAsEqual(expect string) func(*testing.T, string) { + return func(t *testing.T, v string) { + t.Helper() + + if e, a := expect, v; e != a { + t.Errorf("expect %v, got %v", e, a) + } + } +} diff --git a/service/neptune/api.go b/service/neptune/api.go index 80c4cdcc56a..cd0d49a9f12 100644 --- a/service/neptune/api.go +++ b/service/neptune/api.go @@ -6981,6 +6981,9 @@ type CopyDBClusterSnapshotInput struct { // cluster snapshot, and otherwise false. The default is false. CopyTags *bool `type:"boolean"` + // DestinationRegion is used for presigning the request to a given region. + DestinationRegion *string `type:"string"` + // The AWS AWS KMS key ID for an encrypted DB cluster snapshot. The KMS key // ID is the Amazon Resource Name (ARN), KMS key identifier, or the KMS key // alias for the KMS encryption key. @@ -7021,6 +7024,11 @@ type CopyDBClusterSnapshotInput struct { // SourceDBClusterSnapshotIdentifier is a required field SourceDBClusterSnapshotIdentifier *string `type:"string" required:"true"` + // SourceRegion is the source region where the resource exists. This is not + // sent over the wire and is only used for presigning. This value should always + // have the same region as the source ARN. + SourceRegion *string `type:"string" ignore:"true"` + // The tags to assign to the new DB cluster snapshot copy. Tags []*Tag `locationNameList:"Tag" type:"list"` @@ -7073,6 +7081,12 @@ func (s *CopyDBClusterSnapshotInput) SetCopyTags(v bool) *CopyDBClusterSnapshotI return s } +// SetDestinationRegion sets the DestinationRegion field's value. +func (s *CopyDBClusterSnapshotInput) SetDestinationRegion(v string) *CopyDBClusterSnapshotInput { + s.DestinationRegion = &v + return s +} + // SetKmsKeyId sets the KmsKeyId field's value. func (s *CopyDBClusterSnapshotInput) SetKmsKeyId(v string) *CopyDBClusterSnapshotInput { s.KmsKeyId = &v @@ -7091,6 +7105,12 @@ func (s *CopyDBClusterSnapshotInput) SetSourceDBClusterSnapshotIdentifier(v stri return s } +// SetSourceRegion sets the SourceRegion field's value. +func (s *CopyDBClusterSnapshotInput) SetSourceRegion(v string) *CopyDBClusterSnapshotInput { + s.SourceRegion = &v + return s +} + // SetTags sets the Tags field's value. func (s *CopyDBClusterSnapshotInput) SetTags(v []*Tag) *CopyDBClusterSnapshotInput { s.Tags = v @@ -7534,6 +7554,9 @@ type CreateDBClusterInput struct { // deletion protection is enabled. DeletionProtection *bool `type:"boolean"` + // DestinationRegion is used for presigning the request to a given region. + DestinationRegion *string `type:"string"` + // The list of log types that need to be enabled for exporting to CloudWatch // Logs. EnableCloudwatchLogsExports []*string `type:"list"` @@ -7645,6 +7668,11 @@ type CreateDBClusterInput struct { // this DB cluster is created as a Read Replica. ReplicationSourceIdentifier *string `type:"string"` + // SourceRegion is the source region where the resource exists. This is not + // sent over the wire and is only used for presigning. This value should always + // have the same region as the source ARN. + SourceRegion *string `type:"string" ignore:"true"` + // Specifies whether the DB cluster is encrypted. StorageEncrypted *bool `type:"boolean"` @@ -7729,6 +7757,12 @@ func (s *CreateDBClusterInput) SetDeletionProtection(v bool) *CreateDBClusterInp return s } +// SetDestinationRegion sets the DestinationRegion field's value. +func (s *CreateDBClusterInput) SetDestinationRegion(v string) *CreateDBClusterInput { + s.DestinationRegion = &v + return s +} + // SetEnableCloudwatchLogsExports sets the EnableCloudwatchLogsExports field's value. func (s *CreateDBClusterInput) SetEnableCloudwatchLogsExports(v []*string) *CreateDBClusterInput { s.EnableCloudwatchLogsExports = v @@ -7807,6 +7841,12 @@ func (s *CreateDBClusterInput) SetReplicationSourceIdentifier(v string) *CreateD return s } +// SetSourceRegion sets the SourceRegion field's value. +func (s *CreateDBClusterInput) SetSourceRegion(v string) *CreateDBClusterInput { + s.SourceRegion = &v + return s +} + // SetStorageEncrypted sets the StorageEncrypted field's value. func (s *CreateDBClusterInput) SetStorageEncrypted(v bool) *CreateDBClusterInput { s.StorageEncrypted = &v diff --git a/service/neptune/customizations.go b/service/neptune/customizations.go new file mode 100644 index 00000000000..b325229f396 --- /dev/null +++ b/service/neptune/customizations.go @@ -0,0 +1,107 @@ +package neptune + +import ( + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awsutil" + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/request" +) + +func init() { + ops := []string{ + opCopyDBClusterSnapshot, + opCreateDBCluster, + } + initRequest = func(r *request.Request) { + for _, operation := range ops { + if r.Operation.Name == operation { + r.Handlers.Build.PushFront(fillPresignedURL) + } + } + } +} + +func fillPresignedURL(r *request.Request) { + fns := map[string]func(r *request.Request){ + opCopyDBClusterSnapshot: copyDBClusterSnapshotPresign, + opCreateDBCluster: createDBClusterPresign, + } + if !r.ParamsFilled() { + return + } + if f, ok := fns[r.Operation.Name]; ok { + f(r) + } +} + +func copyDBClusterSnapshotPresign(r *request.Request) { + originParams := r.Params.(*CopyDBClusterSnapshotInput) + + if originParams.SourceRegion == nil || originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { + return + } + + originParams.DestinationRegion = r.Config.Region + // preSignedUrl is not required for instances in the same region. + if *originParams.SourceRegion == *originParams.DestinationRegion { + return + } + + newParams := awsutil.CopyOf(r.Params).(*CopyDBClusterSnapshotInput) + originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) +} + +func createDBClusterPresign(r *request.Request) { + originParams := r.Params.(*CreateDBClusterInput) + + if originParams.SourceRegion == nil || originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { + return + } + + originParams.DestinationRegion = r.Config.Region + // preSignedUrl is not required for instances in the same region. + if *originParams.SourceRegion == *originParams.DestinationRegion { + return + } + + newParams := awsutil.CopyOf(r.Params).(*CreateDBClusterInput) + originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) +} + +// presignURL will presign the request by using SoureRegion to sign with. SourceRegion is not +// sent to the service, and is only used to not have the SDKs parsing ARNs. +func presignURL(r *request.Request, sourceRegion *string, newParams interface{}) *string { + cfg := r.Config.Copy(aws.NewConfig(). + WithEndpoint(""). + WithRegion(aws.StringValue(sourceRegion))) + + clientInfo := r.ClientInfo + resolved, err := r.Config.EndpointResolver.EndpointFor( + EndpointsID, aws.StringValue(cfg.Region), + func(opt *endpoints.Options) { + opt.DisableSSL = aws.BoolValue(cfg.DisableSSL) + opt.UseDualStack = aws.BoolValue(cfg.UseDualStack) + }, + ) + if err != nil { + r.Error = err + return nil + } + + clientInfo.Endpoint = resolved.URL + clientInfo.SigningRegion = resolved.SigningRegion + + // Presign a request with modified params + req := request.New(*cfg, clientInfo, r.Handlers, r.Retryer, r.Operation, newParams, r.Data) + req.Operation.HTTPMethod = "GET" + uri, err := req.Presign(5 * time.Minute) // 5 minutes should be enough. + if err != nil { // bubble error back up to original request + r.Error = err + return nil + } + + // We have our URL, set it on params + return &uri +} diff --git a/service/neptune/customizations_test.go b/service/neptune/customizations_test.go new file mode 100644 index 00000000000..e192bc22c0d --- /dev/null +++ b/service/neptune/customizations_test.go @@ -0,0 +1,181 @@ +// +build go1.9 + +package neptune + +import ( + "fmt" + "io/ioutil" + "net/url" + "regexp" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/awstesting" + "github.com/aws/aws-sdk-go/awstesting/unit" +) + +func TestCopyDBClusterSnapshotRequestNoPanic(t *testing.T) { + svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")}) + + f := func() { + // Doesn't panic on nil input + req, _ := svc.CopyDBClusterSnapshotRequest(nil) + req.Sign() + } + if paniced, p := awstesting.DidPanic(f); paniced { + t.Errorf("expect no panic, got %v", p) + } +} + +func TestPresignCrossRegionRequest(t *testing.T) { + const targetRegion = "us-west-2" + + svc := New(unit.Session, &aws.Config{Region: aws.String(targetRegion)}) + + const regexPattern = `^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=%s.+` + + cases := map[string]struct { + Req *request.Request + Assert func(*testing.T, string) + }{ + opCopyDBClusterSnapshot: { + Req: func() *request.Request { + req, _ := svc.CopyDBClusterSnapshotRequest( + &CopyDBClusterSnapshotInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + }) + return req + }(), + Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern, + opCopyDBClusterSnapshot, targetRegion)), + }, + opCreateDBCluster: { + Req: func() *request.Request { + req, _ := svc.CreateDBClusterRequest( + &CreateDBClusterInput{ + SourceRegion: aws.String("us-west-1"), + DBClusterIdentifier: aws.String("foo"), + Engine: aws.String("bar"), + }) + return req + }(), + Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern, + opCreateDBCluster, targetRegion)), + }, + opCopyDBClusterSnapshot + " same region": { + Req: func() *request.Request { + req, _ := svc.CopyDBClusterSnapshotRequest( + &CopyDBClusterSnapshotInput{ + SourceRegion: aws.String("us-west-2"), + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + }) + return req + }(), + Assert: assertAsEmpty(), + }, + opCreateDBCluster + " same region": { + Req: func() *request.Request { + req, _ := svc.CreateDBClusterRequest( + &CreateDBClusterInput{ + SourceRegion: aws.String("us-west-2"), + DBClusterIdentifier: aws.String("foo"), + Engine: aws.String("bar"), + }) + return req + }(), + Assert: assertAsEmpty(), + }, + opCopyDBClusterSnapshot + " presignURL set": { + Req: func() *request.Request { + req, _ := svc.CopyDBClusterSnapshotRequest( + &CopyDBClusterSnapshotInput{ + SourceRegion: aws.String("us-west-1"), + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + PreSignedUrl: aws.String("mockPresignedURL"), + }) + return req + }(), + Assert: assertAsEqual("mockPresignedURL"), + }, + opCreateDBCluster + " presignURL set": { + Req: func() *request.Request { + req, _ := svc.CreateDBClusterRequest( + &CreateDBClusterInput{ + SourceRegion: aws.String("us-west-1"), + DBClusterIdentifier: aws.String("foo"), + Engine: aws.String("bar"), + PreSignedUrl: aws.String("mockPresignedURL"), + }) + return req + }(), + Assert: assertAsEqual("mockPresignedURL"), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + if err := c.Req.Sign(); err != nil { + t.Fatalf("expect no error, got %v", err) + } + b, _ := ioutil.ReadAll(c.Req.HTTPRequest.Body) + q, _ := url.ParseQuery(string(b)) + + u, _ := url.QueryUnescape(q.Get("PreSignedUrl")) + + c.Assert(t, u) + }) + } +} + +func TestPresignWithSourceNotSet(t *testing.T) { + reqs := map[string]*request.Request{} + svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")}) + + reqs[opCopyDBClusterSnapshot], _ = svc.CopyDBClusterSnapshotRequest(&CopyDBClusterSnapshotInput{ + SourceDBClusterSnapshotIdentifier: aws.String("foo"), + TargetDBClusterSnapshotIdentifier: aws.String("bar"), + }) + + for _, req := range reqs { + _, err := req.Presign(5 * time.Minute) + if err != nil { + t.Fatal(err) + } + } +} + +func assertAsRegexMatch(exp string) func(*testing.T, string) { + return func(t *testing.T, v string) { + t.Helper() + + if re, a := regexp.MustCompile(exp), v; !re.MatchString(a) { + t.Errorf("expect %s to match %s", re, a) + } + } +} + +func assertAsEmpty() func(*testing.T, string) { + return func(t *testing.T, v string) { + t.Helper() + + if len(v) != 0 { + t.Errorf("expect empty, got %v", v) + } + } +} + +func assertAsEqual(expect string) func(*testing.T, string) { + return func(t *testing.T, v string) { + t.Helper() + + if e, a := expect, v; e != a { + t.Errorf("expect %v, got %v", e, a) + } + } +} diff --git a/service/rds/customizations.go b/service/rds/customizations.go index d6d2e5f6021..01d381a717e 100644 --- a/service/rds/customizations.go +++ b/service/rds/customizations.go @@ -137,7 +137,7 @@ func presignURL(r *request.Request, sourceRegion *string, newParams interface{}) clientInfo := r.ClientInfo resolved, err := r.Config.EndpointResolver.EndpointFor( - clientInfo.ServiceName, aws.StringValue(cfg.Region), + EndpointsID, aws.StringValue(cfg.Region), func(opt *endpoints.Options) { opt.DisableSSL = aws.BoolValue(cfg.DisableSSL) opt.UseDualStack = aws.BoolValue(cfg.UseDualStack)