diff --git a/aws/resources/ecs_cluster.go b/aws/resources/ecs_cluster.go index 1f5c0cf7..66c41b06 100644 --- a/aws/resources/ecs_cluster.go +++ b/aws/resources/ecs_cluster.go @@ -2,9 +2,10 @@ package resources import ( "context" - "github.com/gruntwork-io/cloud-nuke/util" "time" + "github.com/gruntwork-io/cloud-nuke/util" + "github.com/gruntwork-io/cloud-nuke/telemetry" commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" @@ -124,6 +125,34 @@ func (clusters *ECSClusters) getAll(c context.Context, configObj config.Config) return filteredEcsClusters, nil } +func (clusters *ECSClusters) stopClusterRunningTasks(clusterArn *string) error { + logging.Debugf("[TASK] stopping tasks running on cluster %v", *clusterArn) + // before deleting the cluster, remove the active tasks on that cluster + runningTasks, err := clusters.Client.ListTasks(&ecs.ListTasksInput{ + Cluster: clusterArn, + DesiredStatus: aws.String("RUNNING"), + }) + + if err != nil { + return errors.WithStackTrace(err) + } + + // stop the listed tasks + for _, task := range runningTasks.TaskArns { + _, err := clusters.Client.StopTask(&ecs.StopTaskInput{ + Cluster: clusterArn, + Task: task, + Reason: aws.String("Terminating task due to cluster deletion"), + }) + if err != nil { + logging.Debugf("[TASK] Unable to stop the task %s on cluster %s. Reason: %v", *task, *clusterArn, err) + return errors.WithStackTrace(err) + } + logging.Debugf("[TASK] Success, stopped task %v", *task) + } + return nil +} + func (clusters *ECSClusters) nukeAll(ecsClusterArns []*string) error { numNuking := len(ecsClusterArns) @@ -136,10 +165,18 @@ func (clusters *ECSClusters) nukeAll(ecsClusterArns []*string) error { var nukedEcsClusters []*string for _, clusterArn := range ecsClusterArns { + + // before nuking the clusters, do check active tasks on the cluster and stop all of them + err := clusters.stopClusterRunningTasks(clusterArn) + if err != nil { + logging.Debugf("Error, unable to stop the running stasks on the cluster %s %s", aws.StringValue(clusterArn), err) + return errors.WithStackTrace(err) + } + params := &ecs.DeleteClusterInput{ Cluster: clusterArn, } - _, err := clusters.Client.DeleteCluster(params) + _, err = clusters.Client.DeleteCluster(params) // Record status of this resource e := report.Entry{ @@ -150,7 +187,7 @@ func (clusters *ECSClusters) nukeAll(ecsClusterArns []*string) error { report.Record(e) if err != nil { - logging.Debugf("Error, failed to delete cluster with ARN %s", aws.StringValue(clusterArn)) + logging.Debugf("Error, failed to delete cluster with ARN %s %s", aws.StringValue(clusterArn), err) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking ECS Cluster", }, map[string]interface{}{ diff --git a/aws/resources/ecs_cluster_test.go b/aws/resources/ecs_cluster_test.go index 9be2532c..da2cc990 100644 --- a/aws/resources/ecs_cluster_test.go +++ b/aws/resources/ecs_cluster_test.go @@ -2,6 +2,10 @@ package resources import ( "context" + "regexp" + "testing" + "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecs" "github.com/aws/aws-sdk-go/service/ecs/ecsiface" @@ -9,9 +13,6 @@ import ( "github.com/gruntwork-io/cloud-nuke/telemetry" "github.com/gruntwork-io/cloud-nuke/util" "github.com/stretchr/testify/require" - "regexp" - "testing" - "time" ) type mockedEC2Cluster struct { @@ -21,6 +22,8 @@ type mockedEC2Cluster struct { TagResourceOutput ecs.TagResourceOutput ListTagsForResourceOutput ecs.ListTagsForResourceOutput DeleteClusterOutput ecs.DeleteClusterOutput + ListTasksOutput ecs.ListTasksOutput + StopTaskOutput ecs.StopTaskOutput } func (m mockedEC2Cluster) ListClusters(*ecs.ListClustersInput) (*ecs.ListClustersOutput, error) { @@ -43,6 +46,13 @@ func (m mockedEC2Cluster) DeleteCluster(*ecs.DeleteClusterInput) (*ecs.DeleteClu return &m.DeleteClusterOutput, nil } +func (m mockedEC2Cluster) ListTasks(*ecs.ListTasksInput) (*ecs.ListTasksOutput, error) { + return &m.ListTasksOutput, nil +} +func (m mockedEC2Cluster) StopTask(*ecs.StopTaskInput) (*ecs.StopTaskOutput, error) { + return &m.StopTaskOutput, nil +} + func TestEC2Cluster_GetAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() @@ -84,6 +94,12 @@ func TestEC2Cluster_GetAll(t *testing.T) { }, }, }, + ListTasksOutput: ecs.ListTasksOutput{ + TaskArns: []*string{ + aws.String("task-arn-001"), + aws.String("task-arn-002"), + }, + }, }, } @@ -137,3 +153,23 @@ func TestEC2Cluster_NukeAll(t *testing.T) { err := ec.nukeAll([]*string{aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster1")}) require.NoError(t, err) } + +func TestEC2ClusterWithTasks_NukeAll(t *testing.T) { + telemetry.InitTelemetry("cloud-nuke", "") + t.Parallel() + + ec := ECSClusters{ + Client: mockedEC2Cluster{ + DeleteClusterOutput: ecs.DeleteClusterOutput{}, + ListTasksOutput: ecs.ListTasksOutput{ + TaskArns: []*string{ + aws.String("task-arn-001"), + aws.String("task-arn-002"), + }, + }, + }, + } + + err := ec.nukeAll([]*string{aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster1")}) + require.NoError(t, err) +}