Skip to content

Commit

Permalink
Factored out the FindRouteTableForSubnet command
Browse files Browse the repository at this point in the history
  • Loading branch information
aliceh committed May 15, 2024
1 parent f164c15 commit 1f8ac60
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 36 deletions.
38 changes: 2 additions & 36 deletions cmd/network/verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,45 +463,11 @@ func (e *EgressVerification) isSubnetPublic(ctx context.Context, subnetID string
var routeTable string

// Try and find a Route Table associated with the given subnet
describeRouteTablesOutput, err := e.awsClient.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{
Filters: []types.Filter{
{
Name: aws.String("association.subnet-id"),
Values: []string{subnetID},
},
},
})
if err != nil {
return false, fmt.Errorf("failed to describe route tables associated to subnet %s: %w", subnetID, err)
}

// If there are no associated RouteTables, then the subnet uses the default RoutTable for the VPC
if len(describeRouteTablesOutput.RouteTables) == 0 {
// Get the VPC ID for the subnet
describeSubnetOutput, err := e.awsClient.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{
SubnetIds: []string{subnetID},
})
if err != nil {
return false, err
}
if len(describeSubnetOutput.Subnets) == 0 {
return false, fmt.Errorf("no subnets returned for subnet id %v", subnetID)
}

vpcID := *describeSubnetOutput.Subnets[0].VpcId

// Set the route table to the default for the VPC
routeTable, err = e.findDefaultRouteTableForVPC(ctx, vpcID)
if err != nil {
return false, err
}
} else {
// Set the route table to the one associated with the subnet
routeTable = *describeRouteTablesOutput.RouteTables[0].RouteTableId
}
routeTable, err := utils.FindRouteTableForSubnetForVerification(e.awsClient, subnetID)

// Check that the RouteTable for the subnet has a default route to 0.0.0.0/0
describeRouteTablesOutput, err = e.awsClient.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{
describeRouteTablesOutput, err := e.awsClient.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{
RouteTableIds: []string{routeTable},
})
if err != nil {
Expand Down
76 changes: 76 additions & 0 deletions pkg/utils/network.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package utils

import (
"context"
"fmt"

awsSdk "github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -9,6 +10,12 @@ import (
"github.com/openshift/osdctl/pkg/provider/aws"
)

type verificationAWSClient interface {
DescribeSubnets(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(options *ec2.Options)) (*ec2.DescribeSubnetsOutput, error)
DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(options *ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error)
DescribeRouteTables(ctx context.Context, params *ec2.DescribeRouteTablesInput, optFns ...func(options *ec2.Options)) (*ec2.DescribeRouteTablesOutput, error)
}

// Try and find a Route Table associated with the given subnet

func FindRouteTableForSubnet(awsClient aws.Client, subnetID string) (string, error) {
Expand Down Expand Up @@ -77,3 +84,72 @@ func findDefaultRouteTableForVPC(awsClient aws.Client, vpcID string) (string, er

return "", fmt.Errorf("no default route table found for vpc: %s", vpcID)
}

// Try and find a Route Table associated with the given subnet for Egress Verification

func FindRouteTableForSubnetForVerification(verificationAwsClient verificationAWSClient, subnetID string) (string, error) {

var routeTable string
describeRouteTablesOutput, err := verificationAwsClient.DescribeRouteTables(context.TODO(), &ec2.DescribeRouteTablesInput{
Filters: []types.Filter{
{
Name: awsSdk.String("association.subnet-id"),
Values: []string{subnetID},
},
},
})
if err != nil {
return "", fmt.Errorf("failed to describe route tables associated to subnet %s: %w", subnetID, err)
}

// If there are no associated RouteTables, then the subnet uses the default RoutTable for the VPC
if len(describeRouteTablesOutput.RouteTables) == 0 {
// Get the VPC ID for the subnet
describeSubnetOutput, err := verificationAwsClient.DescribeSubnets(context.TODO(), &ec2.DescribeSubnetsInput{
SubnetIds: []string{subnetID},
})
if err != nil {
return "", err
}
if len(describeSubnetOutput.Subnets) == 0 {
return "", fmt.Errorf("no subnets returned for subnet id %v", subnetID)
}

vpcID := *describeSubnetOutput.Subnets[0].VpcId

// Set the route table to the default for the VPC
routeTable, err = findDefaultRouteTableForVPCForVerification(verificationAwsClient, vpcID)
if err != nil {
return "", err
}
} else {
// Set the route table to the one associated with the subnet
routeTable = *describeRouteTablesOutput.RouteTables[0].RouteTableId
}
return routeTable, err
}

// findDefaultRouteTableForVPC returns the AWS Route Table ID of the VPC's default Route Table
func findDefaultRouteTableForVPCForVerification(awsClient verificationAWSClient, vpcID string) (string, error) {
describeRouteTablesOutput, err := awsClient.DescribeRouteTables(context.TODO(), &ec2.DescribeRouteTablesInput{
Filters: []types.Filter{
{
Name: awsSdk.String("vpc-id"),
Values: []string{vpcID},
},
},
})
if err != nil {
return "", fmt.Errorf("failed to describe route tables associated with vpc %s: %w", vpcID, err)
}

for _, rt := range describeRouteTablesOutput.RouteTables {
for _, assoc := range rt.Associations {
if *assoc.Main {
return *rt.RouteTableId, nil
}
}
}

return "", fmt.Errorf("no default route table found for vpc: %s", vpcID)
}

0 comments on commit 1f8ac60

Please sign in to comment.