From 1f8ac6014307420bda57719cce6f4d48f4d2870c Mon Sep 17 00:00:00 2001 From: Alice Hubenko <ahubenko@redhat.com> Date: Wed, 15 May 2024 15:02:46 -0700 Subject: [PATCH] Factored out the FindRouteTableForSubnet command --- cmd/network/verification.go | 38 +------------------ pkg/utils/network.go | 76 +++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 36 deletions(-) diff --git a/cmd/network/verification.go b/cmd/network/verification.go index a798d6992..7adf2d146 100644 --- a/cmd/network/verification.go +++ b/cmd/network/verification.go @@ -463,45 +463,11 @@ func (e *EgressVerification) isSubnetPublic(ctx context.Context, subnetID string var routeTable string // Try and find a Route Table associated with the given subnet - describeRouteTablesOutput, err := e.awsClient.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{ - Filters: []types.Filter{ - { - Name: aws.String("association.subnet-id"), - Values: []string{subnetID}, - }, - }, - }) - if err != nil { - return false, fmt.Errorf("failed to describe route tables associated to subnet %s: %w", subnetID, err) - } - // If there are no associated RouteTables, then the subnet uses the default RoutTable for the VPC - if len(describeRouteTablesOutput.RouteTables) == 0 { - // Get the VPC ID for the subnet - describeSubnetOutput, err := e.awsClient.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ - SubnetIds: []string{subnetID}, - }) - if err != nil { - return false, err - } - if len(describeSubnetOutput.Subnets) == 0 { - return false, fmt.Errorf("no subnets returned for subnet id %v", subnetID) - } - - vpcID := *describeSubnetOutput.Subnets[0].VpcId - - // Set the route table to the default for the VPC - routeTable, err = e.findDefaultRouteTableForVPC(ctx, vpcID) - if err != nil { - return false, err - } - } else { - // Set the route table to the one associated with the subnet - routeTable = *describeRouteTablesOutput.RouteTables[0].RouteTableId - } + routeTable, err := utils.FindRouteTableForSubnetForVerification(e.awsClient, subnetID) // Check that the RouteTable for the subnet has a default route to 0.0.0.0/0 - describeRouteTablesOutput, err = e.awsClient.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{ + describeRouteTablesOutput, err := e.awsClient.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{ RouteTableIds: []string{routeTable}, }) if err != nil { diff --git a/pkg/utils/network.go b/pkg/utils/network.go index e5514ebe6..49fff1657 100644 --- a/pkg/utils/network.go +++ b/pkg/utils/network.go @@ -1,6 +1,7 @@ package utils import ( + "context" "fmt" awsSdk "github.com/aws/aws-sdk-go-v2/aws" @@ -9,6 +10,12 @@ import ( "github.com/openshift/osdctl/pkg/provider/aws" ) +type verificationAWSClient interface { + DescribeSubnets(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(options *ec2.Options)) (*ec2.DescribeSubnetsOutput, error) + DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(options *ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) + DescribeRouteTables(ctx context.Context, params *ec2.DescribeRouteTablesInput, optFns ...func(options *ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) +} + // Try and find a Route Table associated with the given subnet func FindRouteTableForSubnet(awsClient aws.Client, subnetID string) (string, error) { @@ -77,3 +84,72 @@ func findDefaultRouteTableForVPC(awsClient aws.Client, vpcID string) (string, er return "", fmt.Errorf("no default route table found for vpc: %s", vpcID) } + +// Try and find a Route Table associated with the given subnet for Egress Verification + +func FindRouteTableForSubnetForVerification(verificationAwsClient verificationAWSClient, subnetID string) (string, error) { + + var routeTable string + describeRouteTablesOutput, err := verificationAwsClient.DescribeRouteTables(context.TODO(), &ec2.DescribeRouteTablesInput{ + Filters: []types.Filter{ + { + Name: awsSdk.String("association.subnet-id"), + Values: []string{subnetID}, + }, + }, + }) + if err != nil { + return "", fmt.Errorf("failed to describe route tables associated to subnet %s: %w", subnetID, err) + } + + // If there are no associated RouteTables, then the subnet uses the default RoutTable for the VPC + if len(describeRouteTablesOutput.RouteTables) == 0 { + // Get the VPC ID for the subnet + describeSubnetOutput, err := verificationAwsClient.DescribeSubnets(context.TODO(), &ec2.DescribeSubnetsInput{ + SubnetIds: []string{subnetID}, + }) + if err != nil { + return "", err + } + if len(describeSubnetOutput.Subnets) == 0 { + return "", fmt.Errorf("no subnets returned for subnet id %v", subnetID) + } + + vpcID := *describeSubnetOutput.Subnets[0].VpcId + + // Set the route table to the default for the VPC + routeTable, err = findDefaultRouteTableForVPCForVerification(verificationAwsClient, vpcID) + if err != nil { + return "", err + } + } else { + // Set the route table to the one associated with the subnet + routeTable = *describeRouteTablesOutput.RouteTables[0].RouteTableId + } + return routeTable, err +} + +// findDefaultRouteTableForVPC returns the AWS Route Table ID of the VPC's default Route Table +func findDefaultRouteTableForVPCForVerification(awsClient verificationAWSClient, vpcID string) (string, error) { + describeRouteTablesOutput, err := awsClient.DescribeRouteTables(context.TODO(), &ec2.DescribeRouteTablesInput{ + Filters: []types.Filter{ + { + Name: awsSdk.String("vpc-id"), + Values: []string{vpcID}, + }, + }, + }) + if err != nil { + return "", fmt.Errorf("failed to describe route tables associated with vpc %s: %w", vpcID, err) + } + + for _, rt := range describeRouteTablesOutput.RouteTables { + for _, assoc := range rt.Associations { + if *assoc.Main { + return *rt.RouteTableId, nil + } + } + } + + return "", fmt.Errorf("no default route table found for vpc: %s", vpcID) +}