Migrate S3 to SDK v2 #779

Merged · 2 commits · Nov 25, 2024
204 changes: 106 additions & 98 deletions aws/resources/s3.go
@@ -2,14 +2,18 @@ package resources

import (
"context"
goerr "errors"
"fmt"
"math"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go"

"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/report"
@@ -26,17 +30,17 @@ func (sb S3Buckets) getS3BucketRegion(bucketName string) (string, error) {
Bucket: aws.String(bucketName),
}

result, err := sb.Client.GetBucketLocationWithContext(sb.Context, input)
result, err := sb.Client.GetBucketLocation(sb.Context, input)
if err != nil {
return "", err
}

if result.LocationConstraint == nil {
if result.LocationConstraint == "" {
// GetBucketLocation returns an empty LocationConstraint for us-east-1
// https://github.com/aws/aws-sdk-go/issues/1687
return "us-east-1", nil
}
return *result.LocationConstraint, nil
return string(result.LocationConstraint), nil
}
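In the v2 SDK a per-region client can be derived from the shared config by overriding Options.Region, which is what the region-sensitive tagging call below relies on. A minimal sketch, assuming an aws.Config value named cfg is available (the helper name is hypothetical, not part of this PR):

// Hypothetical helper: build an S3 client pinned to the bucket's region so
// that region-sensitive calls such as GetBucketTagging succeed.
func s3ClientForRegion(cfg aws.Config, region string) *s3.Client {
	return s3.NewFromConfig(cfg, func(o *s3.Options) {
		o.Region = region
	})
}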

// getS3BucketTags returns S3 Bucket tags.
@@ -47,18 +51,18 @@ func (bucket *S3Buckets) getS3BucketTags(bucketName string) (map[string]string,

// Please note that svc argument should be created from a session object which is
// in the same region as the bucket or GetBucketTagging will fail.
result, err := bucket.Client.GetBucketTaggingWithContext(bucket.Context, input)
result, err := bucket.Client.GetBucketTagging(bucket.Context, input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case "NoSuchTagSet":
var apiErr *smithy.OperationError
if goerr.As(err, &apiErr) {
if strings.Contains(apiErr.Error(), "NoSuchTagSet: The TagSet does not exist") {
return nil, nil
}
return nil, err
}
return nil, err
}

return util.ConvertS3TagsToMap(result.TagSet), nil
return util.ConvertS3TypesTagsToMap(result.TagSet), nil
}
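The v2 SDK surfaces service errors as typed smithy errors, so the tag lookup above could also key off the error code instead of matching the error string. A sketch under that assumption (same imports as above; "NoSuchTagSet" is the code S3 returns for a bucket with no tag set):

	var apiErr smithy.APIError
	if goerr.As(err, &apiErr) && apiErr.ErrorCode() == "NoSuchTagSet" {
		// The bucket simply has no tags; treat it as an empty result.
		return nil, nil
	}
	return nil, err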

// S3Bucket - represents S3 bucket
@@ -73,7 +77,7 @@ type S3Bucket struct {

// getAllS3Buckets returns a map of per region AWS S3 buckets which were created before excludeAfter
func (sb S3Buckets) getAll(c context.Context, configObj config.Config) ([]*string, error) {
output, err := sb.Client.ListBucketsWithContext(sb.Context, &s3.ListBucketsInput{})
output, err := sb.Client.ListBuckets(sb.Context, &s3.ListBucketsInput{})
if err != nil {
return nil, errors.WithStackTrace(err)
}
@@ -106,14 +110,14 @@ func (sb S3Buckets) getAll(c context.Context, configObj config.Config) ([]*strin
}

// getBucketNamesPerRegions gets valid bucket names concurrently from list of target buckets
func (sb S3Buckets) getBucketNames(targetBuckets []*s3.Bucket, configObj config.Config) ([]*string, error) {
func (sb S3Buckets) getBucketNames(targetBuckets []types.Bucket, configObj config.Config) ([]*string, error) {
var bucketNames []*string
bucketCh := make(chan *S3Bucket, len(targetBuckets))
var wg sync.WaitGroup

for _, bucket := range targetBuckets {
wg.Add(1)
go func(bucket *s3.Bucket) {
go func(bucket types.Bucket) {
defer wg.Done()
sb.getBucketInfo(bucket, bucketCh, configObj)
}(bucket)
@@ -143,10 +147,10 @@ func (sb S3Buckets) getBucketNames(targetBuckets []*s3.Bucket, configObj config.
}

// getBucketInfo populates the local S3Bucket struct for the passed AWS bucket
func (sb S3Buckets) getBucketInfo(bucket *s3.Bucket, bucketCh chan<- *S3Bucket, configObj config.Config) {
func (sb S3Buckets) getBucketInfo(bucket types.Bucket, bucketCh chan<- *S3Bucket, configObj config.Config) {
var bucketData S3Bucket
bucketData.Name = aws.StringValue(bucket.Name)
bucketData.CreationDate = aws.TimeValue(bucket.CreationDate)
bucketData.Name = aws.ToString(bucket.Name)
bucketData.CreationDate = aws.ToTime(bucket.CreationDate)

bucketRegion, err := sb.getS3BucketRegion(bucketData.Name)
if err != nil {
@@ -194,93 +198,73 @@ func (sb S3Buckets) getBucketInfo(bucket *s3.Bucket, bucketCh chan<- *S3Bucket,
func (sb S3Buckets) emptyBucket(bucketName *string) error {
// Since the error may happen in the inner function handler for the pager, we need a function scoped variable that
// the inner function can set when there is an error.
var errOut error
pageId := 1

// As bucket versioning is managed separately and you can turn off versioning after the bucket is created,
// we need to check if there are any versions in the bucket regardless of the versioning status.
err := sb.Client.ListObjectVersionsPagesWithContext(
sb.Context,
&s3.ListObjectVersionsInput{
Bucket: bucketName,
MaxKeys: aws.Int64(int64(sb.MaxBatchSize())),
},
func(page *s3.ListObjectVersionsOutput, lastPage bool) (shouldContinue bool) {
logging.Debugf("Deleting page %d of object versions (%d objects) from bucket %s", pageId, len(page.Versions), aws.StringValue(bucketName))
if err := sb.deleteObjectVersions(bucketName, page.Versions); err != nil {
logging.Errorf("Error deleting objects versions for page %d from bucket %s: %s", pageId, aws.StringValue(bucketName), err)
errOut = err
return false
}
logging.Debugf("[OK] - deleted page %d of object versions (%d objects) from bucket %s", pageId, len(page.Versions), aws.StringValue(bucketName))

logging.Debugf("Deleting page %d of deletion markers (%d deletion markers) from bucket %s", pageId, len(page.DeleteMarkers), aws.StringValue(bucketName))
if err := sb.deleteDeletionMarkers(bucketName, page.DeleteMarkers); err != nil {
logging.Debugf("Error deleting deletion markers for page %d from bucket %s: %s", pageId, aws.StringValue(bucketName), err)
errOut = err
return false
}
logging.Debugf("[OK] - deleted page %d of deletion markers (%d deletion markers) from bucket %s", pageId, len(page.DeleteMarkers), aws.StringValue(bucketName))

pageId++
return true
},
)
outputs, err := sb.Client.ListObjectVersions(sb.Context, &s3.ListObjectVersionsInput{
Bucket: bucketName,
MaxKeys: aws.Int32(int32(sb.MaxBatchSize())),
})
if err != nil {
return err
}
if errOut != nil {
return errOut
return errors.WithStackTrace(err)
}
return nil

// Handle non versioned buckets.
err = sb.Client.ListObjectsV2PagesWithContext(
sb.Context,
&s3.ListObjectsV2Input{
Bucket: bucketName,
MaxKeys: aws.Int64(int64(sb.MaxBatchSize())),
},
func(page *s3.ListObjectsV2Output, lastPage bool) (shouldContinue bool) {
logging.Debugf("Deleting object page %d (%d objects) from bucket %s", pageId, len(page.Contents), aws.StringValue(bucketName))
if err := sb.deleteObjects(bucketName, page.Contents); err != nil {
logging.Errorf("Error deleting objects for page %d from bucket %s: %s", pageId, aws.StringValue(bucketName), err)
errOut = err
return false
}
logging.Debugf("[OK] - deleted object page %d (%d objects) from bucket %s", pageId, len(page.Contents), aws.StringValue(bucketName))
logging.Debugf("Deleting page %d of object versions (%d objects) from bucket %s", pageId, len(outputs.Versions), aws.ToString(bucketName))
if err := sb.deleteObjectVersions(bucketName, outputs.Versions); err != nil {
logging.Errorf("Error deleting objects versions for page %d from bucket %s: %s", pageId, aws.ToString(bucketName), err)
return errors.WithStackTrace(err)
}
logging.Debugf("[OK] - deleted page %d of object versions (%d objects) from bucket %s", pageId, len(outputs.Versions), aws.ToString(bucketName))

pageId++
return true
},
)
if err != nil {
return err
logging.Debugf("Deleting page %d of object delete markers (%d objects) from bucket %s", pageId, len(outputs.Versions), aws.ToString(bucketName))
if err := sb.deleteDeletionMarkers(bucketName, outputs.DeleteMarkers); err != nil {
logging.Errorf("Error deleting deletion markers for page %d from bucket %s: %s", pageId, aws.ToString(bucketName), err)
return errors.WithStackTrace(err)
}
if errOut != nil {
return errOut
logging.Debugf("[OK] - deleted page %d of deletion markers (%d deletion markers) from bucket %s", pageId, len(outputs.DeleteMarkers), aws.ToString(bucketName))

paginator := s3.NewListObjectsV2Paginator(sb.Client, &s3.ListObjectsV2Input{
Bucket: bucketName,
MaxKeys: aws.Int32(int32(sb.MaxBatchSize())),
})

for paginator.HasMorePages() {

page, err := paginator.NextPage(sb.Context)
if err != nil {
return errors.WithStackTrace(err)
}

logging.Debugf("Deleting object page %d (%d objects) from bucket %s", pageId, len(page.Contents), aws.ToString(bucketName))
if err := sb.deleteObjects(bucketName, page.Contents); err != nil {
logging.Errorf("Error deleting objects for page %d from bucket %s: %s", pageId, aws.ToString(bucketName), err)
return err
}
pageId++
}
return nil
}
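ListObjectVersions returns at most one page of results per call, so a bucket holding more versions than MaxBatchSize can be drained by paging manually on the returned markers. A sketch under the same assumptions as the code above (same Client, Context, and delete helpers; pointer-typed response fields as used elsewhere in this file; the function name is hypothetical):

func (sb S3Buckets) emptyAllObjectVersions(bucketName *string) error {
	input := &s3.ListObjectVersionsInput{
		Bucket:  bucketName,
		MaxKeys: aws.Int32(int32(sb.MaxBatchSize())),
	}
	for {
		page, err := sb.Client.ListObjectVersions(sb.Context, input)
		if err != nil {
			return errors.WithStackTrace(err)
		}
		// Delete the versions and delete markers returned in this page.
		if err := sb.deleteObjectVersions(bucketName, page.Versions); err != nil {
			return errors.WithStackTrace(err)
		}
		if err := sb.deleteDeletionMarkers(bucketName, page.DeleteMarkers); err != nil {
			return errors.WithStackTrace(err)
		}
		if !aws.ToBool(page.IsTruncated) {
			return nil
		}
		// Continue from where this page left off.
		input.KeyMarker = page.NextKeyMarker
		input.VersionIdMarker = page.NextVersionIdMarker
	}
}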

// deleteObjects will delete the provided objects (unversioned) from the specified bucket.
func (sb S3Buckets) deleteObjects(bucketName *string, objects []*s3.Object) error {
func (sb S3Buckets) deleteObjects(bucketName *string, objects []types.Object) error {
if len(objects) == 0 {
logging.Debugf("No objects returned in page")
return nil
}

objectIdentifiers := []*s3.ObjectIdentifier{}
objectIdentifiers := []types.ObjectIdentifier{}
for _, obj := range objects {
objectIdentifiers = append(objectIdentifiers, &s3.ObjectIdentifier{
objectIdentifiers = append(objectIdentifiers, types.ObjectIdentifier{
Key: obj.Key,
})
}
_, err := sb.Client.DeleteObjectsWithContext(
_, err := sb.Client.DeleteObjects(
sb.Context,
&s3.DeleteObjectsInput{
Bucket: bucketName,
Delete: &s3.Delete{
Delete: &types.Delete{
Objects: objectIdentifiers,
Quiet: aws.Bool(false),
},
@@ -290,24 +274,24 @@ func (sb S3Buckets) deleteObjects(bucketName *string, objects []*s3.Object) erro
}
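DeleteObjects reports per-key failures in the response body rather than through the returned error, so a caller that needs to surface partial failures can inspect the Errors slice. A sketch, assuming the output of the call above is kept instead of discarded:

	out, err := sb.Client.DeleteObjects(sb.Context, &s3.DeleteObjectsInput{
		Bucket: bucketName,
		Delete: &types.Delete{Objects: objectIdentifiers, Quiet: aws.Bool(false)},
	})
	if err != nil {
		return errors.WithStackTrace(err)
	}
	// Each entry describes one object that could not be deleted.
	for _, delErr := range out.Errors {
		logging.Errorf("Failed to delete %s: %s", aws.ToString(delErr.Key), aws.ToString(delErr.Message))
	}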

// deleteObjectVersions will delete the provided object versions from the specified bucket.
func (sb S3Buckets) deleteObjectVersions(bucketName *string, objectVersions []*s3.ObjectVersion) error {
func (sb S3Buckets) deleteObjectVersions(bucketName *string, objectVersions []types.ObjectVersion) error {
if len(objectVersions) == 0 {
logging.Debugf("No object versions returned in page")
return nil
}

objectIdentifiers := []*s3.ObjectIdentifier{}
objectIdentifiers := []types.ObjectIdentifier{}
for _, obj := range objectVersions {
objectIdentifiers = append(objectIdentifiers, &s3.ObjectIdentifier{
objectIdentifiers = append(objectIdentifiers, types.ObjectIdentifier{
Key: obj.Key,
VersionId: obj.VersionId,
})
}
_, err := sb.Client.DeleteObjectsWithContext(
_, err := sb.Client.DeleteObjects(
sb.Context,
&s3.DeleteObjectsInput{
Bucket: bucketName,
Delete: &s3.Delete{
Delete: &types.Delete{
Objects: objectIdentifiers,
Quiet: aws.Bool(false),
},
@@ -317,24 +301,24 @@ func (sb S3Buckets) deleteObjectVersions(bucketName *string, objectVersions []*s
}

// deleteDeletionMarkers will delete the provided deletion markers from the specified bucket.
func (sb S3Buckets) deleteDeletionMarkers(bucketName *string, objectDelMarkers []*s3.DeleteMarkerEntry) error {
func (sb S3Buckets) deleteDeletionMarkers(bucketName *string, objectDelMarkers []types.DeleteMarkerEntry) error {
if len(objectDelMarkers) == 0 {
logging.Debugf("No deletion markers returned in page")
return nil
}

objectIdentifiers := []*s3.ObjectIdentifier{}
objectIdentifiers := []types.ObjectIdentifier{}
for _, obj := range objectDelMarkers {
objectIdentifiers = append(objectIdentifiers, &s3.ObjectIdentifier{
objectIdentifiers = append(objectIdentifiers, types.ObjectIdentifier{
Key: obj.Key,
VersionId: obj.VersionId,
})
}
_, err := sb.Client.DeleteObjectsWithContext(
_, err := sb.Client.DeleteObjects(
sb.Context,
&s3.DeleteObjectsInput{
Bucket: bucketName,
Delete: &s3.Delete{
Delete: &types.Delete{
Objects: objectIdentifiers,
Quiet: aws.Bool(false),
},
@@ -349,18 +333,18 @@ func (sb S3Buckets) nukeAllS3BucketObjects(bucketName *string) error {
return fmt.Errorf("Invalid batchsize - %d - should be between %d and %d", sb.MaxBatchSize(), 1, 1000)
}

logging.Debugf("Emptying bucket %s", aws.StringValue(bucketName))
logging.Debugf("Emptying bucket %s", aws.ToString(bucketName))
if err := sb.emptyBucket(bucketName); err != nil {
return err
}
logging.Debugf("[OK] - successfully emptied bucket %s", aws.StringValue(bucketName))
logging.Debugf("[OK] - successfully emptied bucket %s", aws.ToString(bucketName))
return nil
}

// nukeEmptyS3Bucket deletes an empty S3 bucket
func (sb S3Buckets) nukeEmptyS3Bucket(bucketName *string, verifyBucketDeletion bool) error {

_, err := sb.Client.DeleteBucketWithContext(sb.Context, &s3.DeleteBucketInput{
_, err := sb.Client.DeleteBucket(sb.Context, &s3.DeleteBucketInput{
Bucket: bucketName,
})
if err != nil {
@@ -375,23 +359,47 @@ func (sb S3Buckets) nukeEmptyS3Bucket(bucketName *string, verifyBucketDeletion b
// such, we retry this routine up to 3 times for a total of 300 seconds.
const maxRetries = 3
for i := 0; i < maxRetries; i++ {
logging.Debugf("Waiting until bucket (%s) deletion is propagated (attempt %d / %d)", aws.StringValue(bucketName), i+1, maxRetries)
err = sb.Client.WaitUntilBucketNotExistsWithContext(sb.Context, &s3.HeadBucketInput{
Bucket: bucketName,
})
logging.Debugf("Waiting until bucket (%s) deletion is propagated (attempt %d / %d)", aws.ToString(bucketName), i+1, maxRetries)
err = waitForBucketDeletion(sb.Context, sb.Client, aws.ToString(bucketName))
// Exit early if no error
if err == nil {
logging.Debug("Successfully detected bucket deletion.")
return nil
}
logging.Debugf("Error waiting for bucket (%s) deletion propagation (attempt %d / %d)", aws.StringValue(bucketName), i+1, maxRetries)
logging.Debugf("Error waiting for bucket (%s) deletion propagation (attempt %d / %d)", aws.ToString(bucketName), i+1, maxRetries)
logging.Debugf("Underlying error was: %s", err)
}
return err
}

func waitForBucketDeletion(ctx context.Context, client S3API, bucketName string) error {
waiter := s3.NewBucketNotExistsWaiter(client)

for i := 0; i < maxRetries; i++ {
logging.Debugf("Waiting until bucket (%s) deletion is propagated (attempt %d / %d)", bucketName, i+1, maxRetries)

err := waiter.Wait(ctx, &s3.HeadBucketInput{
Bucket: aws.String(bucketName),
}, waitDuration)
if err == nil {
logging.Debugf("Successfully detected bucket deletion.")
return nil
}
logging.Debugf("Waiting until bucket erorr (%v)", err)

if i == maxRetries-1 {
return fmt.Errorf("failed to confirm bucket deletion after %d attempts: %w", maxRetries, err)
}

logging.Debugf("Error waiting for bucket (%s) deletion propagation (attempt %d / %d)", bucketName, i+1, maxRetries)
logging.Debugf("Underlying error was: %s", err)
}

return fmt.Errorf("unexpected error: reached end of retry loop")
}
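waitForBucketDeletion references maxRetries and waitDuration without defining them in this hunk; they are presumably package-level declarations elsewhere in the file. Hypothetical values consistent with the 300-second comment above (the names come from the diff, the values are assumptions):

const (
	maxRetries   = 3                 // assumed: matches the retry comment above
	waitDuration = 100 * time.Second // assumed: 3 attempts of ~100s each, about 300s total
)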

func (sb S3Buckets) nukeS3BucketPolicy(bucketName *string) error {
_, err := sb.Client.DeleteBucketPolicyWithContext(
_, err := sb.Client.DeleteBucketPolicy(
sb.Context,
&s3.DeleteBucketPolicyInput{
Bucket: aws.String(*bucketName),
@@ -443,7 +451,7 @@ func (sb S3Buckets) nukeAll(bucketNames []*string) (delCount int, err error) {

// Record status of this resource
e := report.Entry{
Identifier: aws.StringValue(bucketName),
Identifier: aws.ToString(bucketName),
ResourceType: "S3 Bucket",
Error: err,
}