From 1f8ac6014307420bda57719cce6f4d48f4d2870c Mon Sep 17 00:00:00 2001
From: Alice Hubenko <ahubenko@redhat.com>
Date: Wed, 15 May 2024 15:02:46 -0700
Subject: [PATCH] Factored out the FindRouteTableForSubnet command

---
 cmd/network/verification.go | 38 +------------------
 pkg/utils/network.go        | 76 +++++++++++++++++++++++++++++++++++++
 2 files changed, 78 insertions(+), 36 deletions(-)

diff --git a/cmd/network/verification.go b/cmd/network/verification.go
index a798d6992..7adf2d146 100644
--- a/cmd/network/verification.go
+++ b/cmd/network/verification.go
@@ -463,45 +463,11 @@ func (e *EgressVerification) isSubnetPublic(ctx context.Context, subnetID string
 	var routeTable string
 
 	// Try and find a Route Table associated with the given subnet
-	describeRouteTablesOutput, err := e.awsClient.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{
-		Filters: []types.Filter{
-			{
-				Name:   aws.String("association.subnet-id"),
-				Values: []string{subnetID},
-			},
-		},
-	})
-	if err != nil {
-		return false, fmt.Errorf("failed to describe route tables associated to subnet %s: %w", subnetID, err)
-	}
 
-	// If there are no associated RouteTables, then the subnet uses the default RoutTable for the VPC
-	if len(describeRouteTablesOutput.RouteTables) == 0 {
-		// Get the VPC ID for the subnet
-		describeSubnetOutput, err := e.awsClient.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{
-			SubnetIds: []string{subnetID},
-		})
-		if err != nil {
-			return false, err
-		}
-		if len(describeSubnetOutput.Subnets) == 0 {
-			return false, fmt.Errorf("no subnets returned for subnet id %v", subnetID)
-		}
-
-		vpcID := *describeSubnetOutput.Subnets[0].VpcId
-
-		// Set the route table to the default for the VPC
-		routeTable, err = e.findDefaultRouteTableForVPC(ctx, vpcID)
-		if err != nil {
-			return false, err
-		}
-	} else {
-		// Set the route table to the one associated with the subnet
-		routeTable = *describeRouteTablesOutput.RouteTables[0].RouteTableId
-	}
+	routeTable, err := utils.FindRouteTableForSubnetForVerification(e.awsClient, subnetID)
 
 	// Check that the RouteTable for the subnet has a default route to 0.0.0.0/0
-	describeRouteTablesOutput, err = e.awsClient.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{
+	describeRouteTablesOutput, err := e.awsClient.DescribeRouteTables(ctx, &ec2.DescribeRouteTablesInput{
 		RouteTableIds: []string{routeTable},
 	})
 	if err != nil {
diff --git a/pkg/utils/network.go b/pkg/utils/network.go
index e5514ebe6..49fff1657 100644
--- a/pkg/utils/network.go
+++ b/pkg/utils/network.go
@@ -1,6 +1,7 @@
 package utils
 
 import (
+	"context"
 	"fmt"
 
 	awsSdk "github.com/aws/aws-sdk-go-v2/aws"
@@ -9,6 +10,12 @@ import (
 	"github.com/openshift/osdctl/pkg/provider/aws"
 )
 
+type verificationAWSClient interface {
+	DescribeSubnets(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(options *ec2.Options)) (*ec2.DescribeSubnetsOutput, error)
+	DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(options *ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error)
+	DescribeRouteTables(ctx context.Context, params *ec2.DescribeRouteTablesInput, optFns ...func(options *ec2.Options)) (*ec2.DescribeRouteTablesOutput, error)
+}
+
 // Try and find a Route Table associated with the given subnet
 
 func FindRouteTableForSubnet(awsClient aws.Client, subnetID string) (string, error) {
@@ -77,3 +84,72 @@ func findDefaultRouteTableForVPC(awsClient aws.Client, vpcID string) (string, er
 
 	return "", fmt.Errorf("no default route table found for vpc: %s", vpcID)
 }
+
+// Try and find a Route Table associated with the given subnet for Egress Verification
+
+func FindRouteTableForSubnetForVerification(verificationAwsClient verificationAWSClient, subnetID string) (string, error) {
+
+	var routeTable string
+	describeRouteTablesOutput, err := verificationAwsClient.DescribeRouteTables(context.TODO(), &ec2.DescribeRouteTablesInput{
+		Filters: []types.Filter{
+			{
+				Name:   awsSdk.String("association.subnet-id"),
+				Values: []string{subnetID},
+			},
+		},
+	})
+	if err != nil {
+		return "", fmt.Errorf("failed to describe route tables associated to subnet %s: %w", subnetID, err)
+	}
+
+	// If there are no associated RouteTables, then the subnet uses the default RoutTable for the VPC
+	if len(describeRouteTablesOutput.RouteTables) == 0 {
+		// Get the VPC ID for the subnet
+		describeSubnetOutput, err := verificationAwsClient.DescribeSubnets(context.TODO(), &ec2.DescribeSubnetsInput{
+			SubnetIds: []string{subnetID},
+		})
+		if err != nil {
+			return "", err
+		}
+		if len(describeSubnetOutput.Subnets) == 0 {
+			return "", fmt.Errorf("no subnets returned for subnet id %v", subnetID)
+		}
+
+		vpcID := *describeSubnetOutput.Subnets[0].VpcId
+
+		// Set the route table to the default for the VPC
+		routeTable, err = findDefaultRouteTableForVPCForVerification(verificationAwsClient, vpcID)
+		if err != nil {
+			return "", err
+		}
+	} else {
+		// Set the route table to the one associated with the subnet
+		routeTable = *describeRouteTablesOutput.RouteTables[0].RouteTableId
+	}
+	return routeTable, err
+}
+
+// findDefaultRouteTableForVPC returns the AWS Route Table ID of the VPC's default Route Table
+func findDefaultRouteTableForVPCForVerification(awsClient verificationAWSClient, vpcID string) (string, error) {
+	describeRouteTablesOutput, err := awsClient.DescribeRouteTables(context.TODO(), &ec2.DescribeRouteTablesInput{
+		Filters: []types.Filter{
+			{
+				Name:   awsSdk.String("vpc-id"),
+				Values: []string{vpcID},
+			},
+		},
+	})
+	if err != nil {
+		return "", fmt.Errorf("failed to describe route tables associated with vpc %s: %w", vpcID, err)
+	}
+
+	for _, rt := range describeRouteTablesOutput.RouteTables {
+		for _, assoc := range rt.Associations {
+			if *assoc.Main {
+				return *rt.RouteTableId, nil
+			}
+		}
+	}
+
+	return "", fmt.Errorf("no default route table found for vpc: %s", vpcID)
+}