Skip to content

Commit

Permalink
Fix generating NSG rules while using shared BYO public IP
Browse files Browse the repository at this point in the history
  • Loading branch information
zarvd committed May 16, 2024
1 parent 539b23b commit d091723
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 38 deletions.
19 changes: 14 additions & 5 deletions pkg/provider/azure_fakes.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package provider
import (
"context"

"sigs.k8s.io/cloud-provider-azure/pkg/azclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/config"

"go.uber.org/mock/gomock"

"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes/fake"
"k8s.io/client-go/tools/record"

"sigs.k8s.io/cloud-provider-azure/pkg/azclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/diskclient/mockdiskclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient"
Expand All @@ -41,6 +41,7 @@ import (
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient/mockvmssvmclient"
"sigs.k8s.io/cloud-provider-azure/pkg/consts"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/config"
utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets"
)

Expand Down Expand Up @@ -132,6 +133,14 @@ func GetTestCloud(ctrl *gomock.Controller) (az *Cloud) {

az.regionZonesMap = map[string][]string{az.Location: {"1", "2", "3"}}

{
kubeClient := fake.NewSimpleClientset() // FIXME: inject kubeClient
informerFactory := informers.NewSharedInformerFactory(kubeClient, 0)
az.serviceLister = informerFactory.Core().V1().Services().Lister()
informerFactory.Start(wait.NeverStop)
informerFactory.WaitForCacheSync(wait.NeverStop)
}

return az
}

Expand Down
49 changes: 17 additions & 32 deletions pkg/provider/azure_loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (

"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network"
"github.com/Azure/go-autorest/autorest/azure"

v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -48,6 +47,7 @@ import (
"sigs.k8s.io/cloud-provider-azure/pkg/consts"
"sigs.k8s.io/cloud-provider-azure/pkg/metrics"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil"
"sigs.k8s.io/cloud-provider-azure/pkg/retry"
utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets"
Expand Down Expand Up @@ -2809,10 +2809,10 @@ func (az *Cloud) getExpectedHAModeLoadBalancingRuleProperties(
}

func (az *Cloud) listServicesByPublicIPs(pips []network.PublicIPAddress) ([]*v1.Service, error) {
logger := klog.Background().WithName("listServicesByPublicIPs")
logger := klog.Background().WithName("listServicesByPublicIPs").WithValues("num-pips", len(pips))
var (
svcNames []string
rv []*v1.Service
rv []*v1.Service
ips []string
)

for _, pip := range pips {
Expand All @@ -2825,40 +2825,25 @@ func (az *Cloud) listServicesByPublicIPs(pips []network.PublicIPAddress) ([]*v1.
}
logger.V(4).Info("fetching public IPs", "pip-id", pip.ID)
pip, _, err := az.getPublicIPAddress(resourceID.ResourceGroup, resourceID.ResourceName, azcache.CacheReadTypeDefault)

if err != nil {
return nil, err
}

logger.V(4).Info("fetched public IP", "pip", pip)
v := getServiceFromPIPServiceTags(pip.Tags)
if v != "" {
parts := strings.Split(strings.TrimSpace(v), ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
svcNames = append(svcNames, p)
}
}
ips = append(ips, *pip.IPAddress)
}

for _, svcName := range svcNames {
parts := strings.Split(svcName, "/")
if len(parts) != 2 {
continue
}
ns, svcName := parts[0], parts[1]

logger.Info("fetching service from lister", "ns", ns, "service-name", svcName)
svc, err := az.serviceLister.Services(ns).Get(svcName)
if err != nil {
return nil, fmt.Errorf("get service error: %w", err)
}

rv = append(rv, svc)
logger = logger.WithValues("pips", ips)
allServices, err := az.serviceLister.List(labels.Everything())
if err != nil {
return nil, fmt.Errorf("list all services from lister: %w", err)
}
logger.V(4).Info("Listed all service from lister", "num-all-services", len(allServices))

rv = fnutil.Filter(func(svc *v1.Service) bool {
ingressIPs := fnutil.Map(func(ing v1.LoadBalancerIngress) string { return ing.IP }, svc.Status.LoadBalancer.Ingress)
ingressIPs = fnutil.Filter(func(ip string) bool { return ip != "" }, ingressIPs)
return len(fnutil.Intersection(ingressIPs, ips)) > 0
}, allServices)
logger.V(4).Info("Filtered services by public IPs", "num-target-services", len(rv))

return rv, nil
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/provider/azure_loadbalancer_accesscontrol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1891,6 +1891,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) {
sharedIPSvc := k8sFx.Service().
WithNamespace("ns-02").
WithName("svc-02").
WithIngressIPs([]string{"200.200.0.1"}).
Build()

sharedIPSvc.Spec.Ports = []v1.ServicePort{
Expand Down Expand Up @@ -1925,14 +1926,15 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) {
WithName("svc-01").
WithAllowedServiceTags(allowedServiceTag).
WithAllowedIPRanges(allowedRanges...).
WithIngressIPs([]string{"200.200.0.1"}).
Build()

kubeClient = fake.NewSimpleClientset(&sharedIPSvc, &svc)
informerFactory = informers.NewSharedInformerFactory(kubeClient, 0)
svcLister = informerFactory.Core().V1().Services().Lister()

pip = fx.Azure().PublicIPAddress("pip1").
WithTag(consts.ServiceTagKey, fmt.Sprintf("%s/%s,%s/%s", svc.Namespace, svc.Name, sharedIPSvc.Namespace, sharedIPSvc.Name)).
WithAddress("200.200.0.1").
Build()
frontendIPConfigurations = []*network.FrontendIPConfiguration{
{
Expand Down
10 changes: 10 additions & 0 deletions pkg/provider/loadbalancer/fnutil/slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ func Map[T any, R any](f func(T) R, xs []T) []R {
return rv
}

func Filter[T any](f func(T) bool, xs []T) []T {
var rv []T
for _, x := range xs {
if f(x) {
rv = append(rv, x)
}
}
return rv
}

func RemoveIf[T any](f func(T) bool, xs []T) []T {
var rv []T
for _, x := range xs {
Expand Down
5 changes: 5 additions & 0 deletions pkg/provider/loadbalancer/testutil/fixture/azure_publicip.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,8 @@ func (f *AzurePublicIPAddressFixture) WithTag(key, value string) *AzurePublicIPA
f.pip.Tags[key] = ptr.To(value)
return f
}

func (f *AzurePublicIPAddressFixture) WithAddress(address string) *AzurePublicIPAddressFixture {
f.pip.PublicIPAddressPropertiesFormat.IPAddress = ptr.To(address)
return f
}
8 changes: 8 additions & 0 deletions pkg/provider/loadbalancer/testutil/fixture/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ func (f *KubernetesServiceFixture) UDPNodePorts() []int32 {
return rv
}

func (f *KubernetesServiceFixture) WithIngressIPs(ips []string) *KubernetesServiceFixture {
f.svc.Status.LoadBalancer.Ingress = make([]v1.LoadBalancerIngress, len(ips))
for i, ip := range ips {
f.svc.Status.LoadBalancer.Ingress[i].IP = ip
}
return f
}

func (f *KubernetesServiceFixture) Build() v1.Service {
return f.svc
}
147 changes: 147 additions & 0 deletions tests/e2e/network/network_security_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,153 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func(
})
})
})

When("creating 2 LoadBalancer services with shared BYO public IP", func() {
It("should add rules independently", func() {

const (
Deployment1Name = "app-01"
Deployment2Name = "app-02"

Service1Name = "svc-01"
Service2Name = "svc-02"
)

var (
app1Port int32 = 80
app2Port int32 = 81
replicas int32 = 2
ipv4PIPName, ipv6PIPName string
ipv4PIPs, ipv6PIPs []netip.Addr
)

By("Creating shared BYO public IP")
{
v4Enabled, v6Enabled := utils.IfIPFamiliesEnabled(azureClient.IPFamily)
if v4Enabled {
// FIXME: avoid duplicated get name with suffix
base := fmt.Sprintf("%s-pip", namespace.Name)
ip, cleanup := createPIP(azureClient, base, false)
ipv4PIPName = utils.GetNameWithSuffix(base, utils.Suffixes[false])
ipv4PIPs = append(ipv4PIPs, netip.MustParseAddr(ip))
DeferCleanup(cleanup)
}
if v6Enabled {
base := fmt.Sprintf("%s-pip", namespace.Name)
ip, cleanup := createPIP(azureClient, base, true)
ipv6PIPName = utils.GetNameWithSuffix(base, utils.Suffixes[true])
ipv6PIPs = append(ipv6PIPs, netip.MustParseAddr(ip))
DeferCleanup(cleanup)
}
logger.Info("Created BYO public IP", "v4-PIP", ipv4PIPs, "v6-PIP", ipv6PIPs, "v4-PIP-Name", ipv4PIPName, "v6-PIP-Name", ipv6PIPName)
}

deployment1 := createDeploymentManifest(Deployment1Name, map[string]string{
"app": Deployment1Name,
}, &app1Port, nil)
deployment1.Spec.Replicas = &replicas
_, err := k8sClient.AppsV1().Deployments(namespace.Name).Create(context.Background(), deployment1, metav1.CreateOptions{})
Expect(err).NotTo(HaveOccurred())

deployment2 := createDeploymentManifest(Deployment2Name, map[string]string{
"app": Deployment2Name,
}, &app2Port, nil)
deployment2.Spec.Replicas = &replicas
_, err = k8sClient.AppsV1().Deployments(namespace.Name).Create(context.Background(), deployment2, metav1.CreateOptions{})
Expect(err).NotTo(HaveOccurred())

By("Creating service 1", func() {
var (
labels = map[string]string{
"app": Deployment1Name,
}
annotations = map[string]string{
consts.ServiceAnnotationPIPNameDualStack[false]: ipv4PIPName,
consts.ServiceAnnotationPIPNameDualStack[true]: ipv6PIPName,
}
ports = []v1.ServicePort{{
Port: app1Port,
TargetPort: intstr.FromInt32(app1Port),
}}
)
rv := createAndExposeDefaultServiceWithAnnotation(k8sClient, azureClient.IPFamily, Service1Name, namespace.Name, labels, annotations, ports)
ipv4s, ipv6s := groupIPsByFamily(mustParseIPs(derefSliceOfStringPtr(rv)))
logger.Info("Created the first LoadBalancer service", "svc-name", Service1Name, "v4-IPs", ipv4s, "v6-IPs", ipv6s)
Expect(ipv4s).To(Equal(ipv4PIPs))
Expect(ipv6s).To(Equal(ipv6PIPs))
})

By("Creating service 2", func() {
var (
labels = map[string]string{
"app": Deployment2Name,
}
annotations = map[string]string{
consts.ServiceAnnotationPIPNameDualStack[false]: ipv4PIPName,
consts.ServiceAnnotationPIPNameDualStack[true]: ipv6PIPName,
}
ports = []v1.ServicePort{{
Port: app2Port,
TargetPort: intstr.FromInt32(app2Port),
}}
)

rv := createAndExposeDefaultServiceWithAnnotation(k8sClient, azureClient.IPFamily, Service2Name, namespace.Name, labels, annotations, ports)
ipv4s, ipv6s := groupIPsByFamily(mustParseIPs(derefSliceOfStringPtr(rv)))
logger.Info("Created the second LoadBalancer service", "svc-name", Service2Name, "v4-IPs", ipv4s, "v6-IPs", ipv6s)
Expect(ipv4s).To(Equal(ipv4PIPs))
Expect(ipv6s).To(Equal(ipv6PIPs))
})

var validator *SecurityGroupValidator
By("Getting the cluster security groups", func() {
rv, err := azureClient.GetClusterSecurityGroups()
Expect(err).NotTo(HaveOccurred())

validator = NewSecurityGroupValidator(rv)
})

By("Checking if the rule for allowing traffic for app 01", func() {
var (
expectedProtocol = aznetwork.SecurityRuleProtocolTCP
expectedDstPorts = []string{strconv.FormatInt(int64(app1Port), 10)}
)

By("Checking if the rule for allowing traffic from Internet exists")

if len(ipv4PIPs) > 0 {
Expect(
validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, ipv4PIPs, expectedDstPorts),
).To(BeTrue(), "Should have a rule for allowing IPv4 traffic from Internet")
}

if len(ipv6PIPs) > 0 {
Expect(
validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, ipv6PIPs, expectedDstPorts),
).To(BeTrue(), "Should have a rule for allowing IPv6 traffic from Internet")
}
})

By("Checking if the rule for allowing traffic for app 02", func() {
var (
expectedProtocol = aznetwork.SecurityRuleProtocolTCP
expectedDstPorts = []string{strconv.FormatInt(int64(app2Port), 10)}
)
By("Checking if the rule for allowing traffic from Internet exists")
if len(ipv4PIPs) > 0 {
Expect(
validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, ipv4PIPs, expectedDstPorts),
).To(BeTrue(), "Should have a rule for allowing IPv4 traffic from Internet")
}

if len(ipv6PIPs) > 0 {
Expect(
validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, ipv6PIPs, expectedDstPorts),
).To(BeTrue(), "Should have a rule for allowing IPv6 traffic from Internet")
}
})
})
})
})

type SecurityGroupValidator struct {
Expand Down

0 comments on commit d091723

Please sign in to comment.