Skip to content

Commit

Permalink
IND-1931 Vulnerability Remediation, aws-sdk-go migration from v1 to v2
Browse files Browse the repository at this point in the history
  • Loading branch information
mohanmanikanta2299 committed Jan 21, 2025
1 parent 5eb1507 commit b237a2f
Show file tree
Hide file tree
Showing 19 changed files with 244 additions and 180 deletions.
3 changes: 1 addition & 2 deletions cmd/discover/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"strings"
Expand All @@ -32,7 +31,7 @@ func main() {

var w io.Writer = os.Stderr
if quiet {
w = ioutil.Discard
w = io.Discard
}
l := log.New(w, "", 0)

Expand Down
35 changes: 26 additions & 9 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ require (
github.com/Azure/azure-sdk-for-go v44.0.0+incompatible
github.com/Azure/go-autorest/autorest v0.11.18
github.com/Azure/go-autorest/autorest/azure/auth v0.5.0
github.com/aws/aws-sdk-go v1.44.262
github.com/aws/aws-sdk-go-v2/service/ec2 v1.200.0
github.com/denverdino/aliyungo v0.0.0-20170926055100-d3308649c661
github.com/digitalocean/godo v1.7.5
github.com/gophercloud/gophercloud v0.1.0
github.com/hashicorp/go-discover/provider/gce v0.0.0-20240829171124-547b9abd20f6
github.com/hashicorp/go-discover/provider/gce v0.0.0-20241120163552-5eb1507d16b4
github.com/hashicorp/go-multierror v1.0.0
github.com/hashicorp/mdns v1.0.1
github.com/hashicorp/vic v1.5.1-0.20190403131502-bbfe86ec9443
Expand All @@ -28,6 +28,18 @@ require (
k8s.io/client-go v0.22.2
)

require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.28 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.28 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.9 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.10 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 // indirect
github.com/aws/smithy-go v1.22.1 // indirect
)

require (
cloud.google.com/go/auth v0.9.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
Expand All @@ -41,6 +53,11 @@ require (
github.com/Azure/go-autorest/logger v0.2.1 // indirect
github.com/Azure/go-autorest/tracing v0.6.0 // indirect
github.com/abdullin/seq v0.0.0-20160510034733-d5467c17e7af // indirect
github.com/aws/aws-sdk-go-v2 v1.33.0
github.com/aws/aws-sdk-go-v2/config v1.29.1
github.com/aws/aws-sdk-go-v2/credentials v1.17.54
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24
github.com/aws/aws-sdk-go-v2/service/ecs v1.53.8
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dimchansky/utfbom v1.1.0 // indirect
github.com/dnaeon/go-vcr v1.0.1 // indirect
Expand Down Expand Up @@ -78,13 +95,13 @@ require (
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.24.0 // indirect
golang.org/x/term v0.23.0 // indirect
golang.org/x/text v0.17.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/term v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.6.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
google.golang.org/api v0.195.0 // indirect
Expand All @@ -98,7 +115,7 @@ require (
gopkg.in/resty.v1 v1.12.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.9.0 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/utils v0.0.0-20210819203725-bdf08cb9a70a // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.1.2 // indirect
sigs.k8s.io/yaml v1.2.0 // indirect
Expand Down
79 changes: 46 additions & 33 deletions go.sum

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions provider/aliyun/aliyun_discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package aliyun

import (
"fmt"
"io/ioutil"
"io"
"log"

"github.com/denverdino/aliyungo/common"
Expand Down Expand Up @@ -39,7 +39,7 @@ func (p *Provider) Addrs(args map[string]string, l *log.Logger) ([]string, error
}

if l == nil {
l = log.New(ioutil.Discard, "", 0)
l = log.New(io.Discard, "", 0)
}

region := args["region"]
Expand Down
149 changes: 79 additions & 70 deletions provider/aws/aws_discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@
package aws

import (
"context"
"encoding/json"
"fmt"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/service/ecs"
"io/ioutil"
"io"
"log"
"net/http"
"os"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go-v2/service/ecs"
ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types"
)

type Provider struct{}
Expand Down Expand Up @@ -61,11 +63,11 @@ func (p *Provider) Help() string {

func (p *Provider) Addrs(args map[string]string, l *log.Logger) ([]string, error) {
if args["provider"] != "aws" {
return nil, fmt.Errorf("discover-aws: invalid provider " + args["provider"])
return nil, fmt.Errorf("%s", "discover-aws: invalid provider "+args["provider"])
}

if l == nil {
l = log.New(ioutil.Discard, "", 0)
l = log.New(io.Discard, "", 0)
}

region := args["region"]
Expand Down Expand Up @@ -122,8 +124,8 @@ func (p *Provider) Addrs(args map[string]string, l *log.Logger) ([]string, error
}
} else {
l.Printf("[INFO] discover-aws: Region not provided. Looking up region in ec2 metadata...")
ec2meta := ec2metadata.New(session.New())
identity, err := ec2meta.GetInstanceIdentityDocument()
ec2meta := imds.New(imds.Options{})
identity, err := ec2meta.GetInstanceIdentityDocument(context.TODO(), &imds.GetInstanceIdentityDocumentInput{})
if err != nil {
return nil, fmt.Errorf("discover-aws: GetInstanceIdentityDocument failed: %s", err)
}
Expand All @@ -133,33 +135,36 @@ func (p *Provider) Addrs(args map[string]string, l *log.Logger) ([]string, error
l.Printf("[INFO] discover-aws: Region is %s", region)

l.Printf("[DEBUG] discover-aws: Creating session...")
config := aws.Config{
Region: &region,
Credentials: credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.StaticProvider{
Value: credentials.Value{
AccessKeyID: accessKey,
SecretAccessKey: secretKey,
SessionToken: sessionToken,
},
},
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{},
defaults.RemoteCredProvider(*(defaults.Config()), defaults.Handlers()),
}),
}
if endpoint != "" {
l.Printf("[INFO] discover-aws: Endpoint is %s", endpoint)
config.Endpoint = &endpoint
var cfg aws.Config
var err error
if accessKey != "" && secretKey != "" {
log.Println("Using static credentials provider")
staticCreds := credentials.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken)
cfg, err = config.LoadDefaultConfig(context.TODO(), config.WithRegion(region),
config.WithCredentialsProvider(aws.NewCredentialsCache(staticCreds)))
if err != nil {
l.Printf("unable to load SDK config with Static Provider, %v", err)
}
} else {
log.Println("Using default credential chain")
cfg, err = config.LoadDefaultConfig(context.TODO(),
config.WithRegion(region), // Specify your region
)
if err != nil {
return nil, fmt.Errorf("unable to load SDK config with default credential chain, %s", err)
}
}

// Split here for ec2 vs ecs decision tree
if service == "ecs" {
svc := ecs.New(session.New(), &config)
svc := ecs.NewFromConfig(cfg, func(o *ecs.Options) {
if endpoint != "" {
o.BaseEndpoint = aws.String(endpoint)
}
})

log.Printf("[INFO] discover-aws: Filter ECS tasks with %s=%s", tagKey, tagValue)
var clusterArns []*string
var clusterArns []string

// If an ECS Cluster Name (ARN) was specified, dont lookup all the cluster arns
if ecsCluster == "" {
Expand All @@ -169,12 +174,12 @@ func (p *Provider) Addrs(args map[string]string, l *log.Logger) ([]string, error
}
clusterArns = arns
} else {
clusterArns = []*string{&ecsCluster}
clusterArns = []string{ecsCluster}
}

var taskIps []string
for _, clusterArn := range clusterArns {
taskArns, err := getEcsTasks(svc, clusterArn, &ecsFamily)
taskArns, err := getEcsTasks(svc, &clusterArn, &ecsFamily)
if err != nil {
return nil, fmt.Errorf("discover-aws: Failed to get ECS Tasks: %s", err)
}
Expand All @@ -185,7 +190,7 @@ func (p *Provider) Addrs(args map[string]string, l *log.Logger) ([]string, error
pageLimit := 100
for i := 0; i < len(taskArns); i += pageLimit {
taskGroup := taskArns[i:min(i+pageLimit, len(taskArns))]
ecsTaskIps, err := getEcsTaskIps(svc, clusterArn, taskGroup, &tagKey, &tagValue)
ecsTaskIps, err := getEcsTaskIps(svc, &clusterArn, taskGroup, &tagKey, &tagValue)
if err != nil {
return nil, fmt.Errorf("discover-aws: Failed to get ECS Task IPs: %s", err)
}
Expand All @@ -199,18 +204,22 @@ func (p *Provider) Addrs(args map[string]string, l *log.Logger) ([]string, error

// When not using ECS continue with the default EC2 search

svc := ec2.New(session.New(), &config)
svc := ec2.NewFromConfig(cfg, func(o *ec2.Options) {
if endpoint != "" {
o.BaseEndpoint = aws.String(endpoint)
}
})

l.Printf("[INFO] discover-aws: Filter instances with %s=%s", tagKey, tagValue)
resp, err := svc.DescribeInstances(&ec2.DescribeInstancesInput{
Filters: []*ec2.Filter{
&ec2.Filter{
resp, err := svc.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
Filters: []types.Filter{
{
Name: aws.String("tag:" + tagKey),
Values: []*string{aws.String(tagValue)},
Values: []string{tagValue},
},
&ec2.Filter{
{
Name: aws.String("instance-state-name"),
Values: []*string{aws.String("running")},
Values: []string{"running"},
},
},
})
Expand Down Expand Up @@ -276,18 +285,17 @@ func min(a, b int) int {
return b
}

func getEcsClusters(svc *ecs.ECS) ([]*string, error) {
pageNum := 0
var clusterArns []*string
err := svc.ListClustersPages(&ecs.ListClustersInput{}, func(page *ecs.ListClustersOutput, lastPage bool) bool {
pageNum++
clusterArns = append(clusterArns, page.ClusterArns...)
log.Printf("[DEBUG] discover-aws: Retrieved %d TaskArns from page %d", len(clusterArns), pageNum)
return !lastPage // return false to exit page function
})
func getEcsClusters(svc *ecs.Client) ([]string, error) {
var clusterArns []string
paginator := ecs.NewListClustersPaginator(svc, &ecs.ListClustersInput{})

if err != nil {
return nil, fmt.Errorf("ListClusters failed: %s", err)
for paginator.HasMorePages() {
page, err := paginator.NextPage(context.TODO())
if err != nil {
return nil, fmt.Errorf("ListClusters failed: %s", err)
}
clusterArns = append(clusterArns, page.ClusterArns...)
log.Printf("[DEBUG] discover-aws: Retrieved %d ClusterArns", len(clusterArns))
}

return clusterArns, nil
Expand All @@ -304,7 +312,7 @@ func getECSTaskMetadata() (ECSTaskMeta, error) {
if err != nil {
return metadataResp, fmt.Errorf("calling metadata uri: %s", err)
}
respBytes, err := ioutil.ReadAll(resp.Body)
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return metadataResp, fmt.Errorf("reading metadata uri response body: %s", err)
}
Expand All @@ -325,36 +333,37 @@ func getEcsTaskRegion(e ECSTaskMeta) (string, error) {
return a.Region, nil
}

func getEcsTasks(svc *ecs.ECS, clusterArn *string, family *string) ([]*string, error) {
var taskArns []*string
func getEcsTasks(svc *ecs.Client, clusterArn *string, family *string) ([]string, error) {
var taskArns []string
lti := ecs.ListTasksInput{
Cluster: clusterArn,
DesiredStatus: aws.String("RUNNING"),
DesiredStatus: ecstypes.DesiredStatusRunning,
}
if *family != "" {
lti.Family = family
}

paginator := ecs.NewListTasksPaginator(svc, &lti)

pageNum := 0
err := svc.ListTasksPages(&lti, func(page *ecs.ListTasksOutput, lastPage bool) bool {
for paginator.HasMorePages() {
page, err := paginator.NextPage(context.TODO())
if err != nil {
return nil, fmt.Errorf("ListTasks failed: %w", err)
}
pageNum++
taskArns = append(taskArns, page.TaskArns...)
log.Printf("[DEBUG] discover-aws: Retrieved %d TaskArns from page %d", len(taskArns), pageNum)
return !lastPage // return false to exit page function
})

if err != nil {
return nil, fmt.Errorf("ListTasks failed: %s", err)
}

return taskArns, nil
}

func getEcsTaskIps(svc *ecs.ECS, clusterArn *string, taskArns []*string, tagKey *string, tagValue *string) ([]string, error) {
func getEcsTaskIps(svc *ecs.Client, clusterArn *string, taskArns []string, tagKey *string, tagValue *string) ([]string, error) {
// Describe all the tasks listed for this cluster
taskDescriptions, err := svc.DescribeTasks(&ecs.DescribeTasksInput{
taskDescriptions, err := svc.DescribeTasks(context.TODO(), &ecs.DescribeTasksInput{
Cluster: clusterArn,
Include: []*string{aws.String(ecs.TaskFieldTags)},
Include: []ecstypes.TaskField{ecstypes.TaskFieldTags},
Tasks: taskArns,
})

Expand All @@ -376,7 +385,7 @@ func getEcsTaskIps(svc *ecs.ECS, clusterArn *string, taskArns []*string, tagKey

if *taskDescription.DesiredStatus == "RUNNING" {
log.Printf("[INFO] discover-aws: Found Running Instance: %s", *taskDescription.TaskArn)
ip := getIpFromTaskDescription(taskDescription)
ip := getIpFromTaskDescription(&taskDescription)

if ip != nil {
log.Printf("[DEBUG] discover-aws: Found Private IP: %s", *ip)
Expand All @@ -394,7 +403,7 @@ func getEcsTaskIps(svc *ecs.ECS, clusterArn *string, taskArns []*string, tagKey
return ipList, nil
}

func getIpFromTaskDescription(taskDesc *ecs.Task) *string {
func getIpFromTaskDescription(taskDesc *ecstypes.Task) *string {
log.Printf("[DEBUG] discover-aws: Searching %d attachments for IPs", len(taskDesc.Attachments))
for _, attachment := range taskDesc.Attachments {

Expand Down
Loading

0 comments on commit b237a2f

Please sign in to comment.