Skip to content

Commit

Permalink
Refactor fetching retain ports when reconciling SecurityGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
zarvd committed Jun 17, 2024
1 parent 554d800 commit 71bff3b
Show file tree
Hide file tree
Showing 15 changed files with 743 additions and 473 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,20 @@ import (
"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network"
"k8s.io/utils/ptr"

"sigs.k8s.io/cloud-provider-azure/internal/testutil"
"sigs.k8s.io/cloud-provider-azure/pkg/consts"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/securitygroup"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/testutil"
)

// NoiseSecurityRules returns N non cloud-provider-specific security rules.
// It's not random, but it's good enough for testing.
func (f *AzureFixture) NoiseSecurityRules(nRules int) []network.SecurityRule {
func (f *AzureFixture) NoiseSecurityRules() []network.SecurityRule {
return f.NNoiseSecurityRules(3)
}

func (f *AzureFixture) NNoiseSecurityRules(nRules int) []network.SecurityRule {
var (
rv = make([]network.SecurityRule, 0, nRules)
protocolByID = func(id int) network.SecurityRuleProtocol {
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ func (f *KubernetesServiceFixture) WithAllowedServiceTags(parts ...string) *Kube
return f
}

func (f *KubernetesServiceFixture) WithDisableFloatingIP() *KubernetesServiceFixture {
f.svc.Annotations[consts.ServiceAnnotationDisableLoadBalancerFloatingIP] = "true"
return f
}

func (f *KubernetesServiceFixture) WithLoadBalancerSourceRanges(parts ...string) *KubernetesServiceFixture {
f.svc.Spec.LoadBalancerSourceRanges = parts
return f
Expand Down
155 changes: 19 additions & 136 deletions pkg/provider/azure_loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"unicode"

"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network"
"github.com/Azure/go-autorest/autorest/azure"
v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -47,7 +46,6 @@ import (
"sigs.k8s.io/cloud-provider-azure/pkg/consts"
"sigs.k8s.io/cloud-provider-azure/pkg/metrics"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil"
"sigs.k8s.io/cloud-provider-azure/pkg/retry"
utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets"
Expand Down Expand Up @@ -144,7 +142,7 @@ func (az *Cloud) reconcileService(_ context.Context, clusterName string, service

serviceIPs := lbIPsPrimaryPIPs
klog.V(2).Infof("reconcileService: reconciling security group for service %q with IPs %q, wantLb = true", serviceName, serviceIPs)
if _, err := az.reconcileSecurityGroup(clusterName, service, ptr.Deref(lb.Name, ""), fipConfigs, serviceIPs, true /* wantLb */); err != nil {
if _, err := az.reconcileSecurityGroup(clusterName, service, ptr.Deref(lb.Name, ""), serviceIPs, true /* wantLb */); err != nil {
klog.Errorf("reconcileSecurityGroup(%s) failed: %#v", serviceName, err)
return nil, err
}
Expand Down Expand Up @@ -324,13 +322,7 @@ func (az *Cloud) EnsureLoadBalancerDeleted(_ context.Context, clusterName string
serviceIPsToCleanup := lbIPsPrimaryPIPs
klog.V(2).Infof("EnsureLoadBalancerDeleted: reconciling security group for service %q with IPs %q, wantLb = false", serviceName, serviceIPsToCleanup)

_, _, fipConfigs, err := az.getServiceLoadBalancerStatus(service, lb)
if err != nil {
klog.Errorf("EnsureLoadBalancerDeleted: getServiceLoadBalancerStatus(%s) failed: %v", serviceName, err)
return err
}

_, err = az.reconcileSecurityGroup(clusterName, service, ptr.Deref(lb.Name, ""), fipConfigs, serviceIPsToCleanup, false /* wantLb */)
_, err = az.reconcileSecurityGroup(clusterName, service, ptr.Deref(lb.Name, ""), serviceIPsToCleanup, false /* wantLb */)
if err != nil {
return err
}
Expand Down Expand Up @@ -2808,108 +2800,11 @@ func (az *Cloud) getExpectedHAModeLoadBalancingRuleProperties(
return props, nil
}

func (az *Cloud) listServicesByPublicIPs(pips []network.PublicIPAddress) ([]*v1.Service, error) {
logger := klog.Background().WithName("listServicesByPublicIPs").WithValues("num-pips", len(pips))
var (
rv []*v1.Service
ips []string
)

for _, pip := range pips {
if pip.ID == nil { // FIXME: it should not be nil
continue
}
resourceID, err := azure.ParseResourceID(*pip.ID)
if err != nil { // FIXME: it should never happen except for testing
continue
}
logger.V(4).Info("fetching public IPs", "pip-id", pip.ID)
pip, _, err := az.getPublicIPAddress(resourceID.ResourceGroup, resourceID.ResourceName, azcache.CacheReadTypeDefault)
if err != nil {
return nil, err
}
ips = append(ips, *pip.IPAddress)
}

logger = logger.WithValues("pips", ips)
allServices, err := az.serviceLister.List(labels.Everything())
if err != nil {
return nil, fmt.Errorf("list all services from lister: %w", err)
}
logger.V(4).Info("Listed all service from lister", "num-all-services", len(allServices))

rv = fnutil.Filter(func(svc *v1.Service) bool {
ingressIPs := fnutil.Map(func(ing v1.LoadBalancerIngress) string { return ing.IP }, svc.Status.LoadBalancer.Ingress)
ingressIPs = fnutil.Filter(func(ip string) bool { return ip != "" }, ingressIPs)
return len(fnutil.Intersection(ingressIPs, ips)) > 0
}, allServices)
logger.V(4).Info("Filtered services by public IPs", "num-target-services", len(rv))

return rv, nil
}

// listSharedIPPortMapping lists the shared IP port mapping for the service excluding the service itself.
// There are scenarios where multiple services share the same public IP,
// and in order to clean up the security rules, we need to know the port mapping of the shared IP.
func (az *Cloud) listSharedIPPortMapping(svc *v1.Service, publicIPs []network.PublicIPAddress) (map[network.SecurityRuleProtocol][]int32, error) {
var (
logger = klog.Background().WithName("listSharedIPPortMapping").WithValues("service-name", svc.Name)
rv = make(map[network.SecurityRuleProtocol][]int32)
convertProtocol = func(protocol v1.Protocol) (network.SecurityRuleProtocol, error) {
switch protocol {
case v1.ProtocolTCP:
return network.SecurityRuleProtocolTCP, nil
case v1.ProtocolUDP:
return network.SecurityRuleProtocolUDP, nil
case v1.ProtocolSCTP:
return network.SecurityRuleProtocolAsterisk, nil
}
return "", fmt.Errorf("unsupported protocol %s", protocol)
}
)

services, err := az.listServicesByPublicIPs(publicIPs)
if err != nil {
logger.Error(err, "Failed to list services by public IPs")
return nil, err
}

for _, s := range services {
logger.V(4).Info("iterating service", "service", s.Name, "namespace", s.Namespace)
if svc.Namespace == s.Namespace && svc.Name == s.Name {
// skip the service itself
continue
}

for _, port := range s.Spec.Ports {
protocol, err := convertProtocol(port.Protocol)
if err != nil {
return nil, err
}

var p int32
if consts.IsK8sServiceDisableLoadBalancerFloatingIP(s) {
p = port.NodePort
} else {
p = port.Port
}
logger.V(4).Info("adding port mapping", "protocol", protocol, "port", p)

rv[protocol] = append(rv[protocol], p)
}
}

logger.V(4).Info("retain port mapping", "port-mapping", rv)

return rv, nil
}

// This reconciles the Network Security Group similar to how the LB is reconciled.
// This entails adding required, missing SecurityRules and removing stale rules.
func (az *Cloud) reconcileSecurityGroup(
clusterName string, service *v1.Service,
lbName string,
fipConfigs []*network.FrontendIPConfiguration,
lbIPs []string, wantLb bool,
) (*network.SecurityGroup, error) {
logger := klog.Background().WithName("reconcileSecurityGroup").
Expand All @@ -2919,17 +2814,12 @@ func (az *Cloud) reconcileSecurityGroup(
WithValues("delete-lb", !wantLb)
logger.V(2).Info("Starting")

ctx := klog.NewContext(context.Background(), logger)

if wantLb && len(lbIPs) == 0 {
return nil, fmt.Errorf("no load balancer IP for setting up security rules for service %s", service.Name)
}

var publicIPs []network.PublicIPAddress
for _, fipConfig := range fipConfigs {
if fipConfig.PublicIPAddress != nil {
publicIPs = append(publicIPs, *fipConfig.PublicIPAddress)
}
}

additionalIPs, err := loadbalancer.AdditionalPublicIPs(service)
if wantLb && err != nil {
return nil, fmt.Errorf("unable to get additional public IPs: %w", err)
Expand Down Expand Up @@ -3000,15 +2890,23 @@ func (az *Cloud) reconcileSecurityGroup(
backendIPv6Addresses, _ = iputil.ParseAddresses(backendIPv6List)
}

{
// Disassociate all IPs from the security group
dstIPv4Addresses := append(lbIPv4Addresses, backendIPv4Addresses...)
dstIPv4Addresses = append(dstIPv4Addresses, additionalIPv4Addresses...)
var (
dstIPv4Addresses = additionalIPv4Addresses
dstIPv6Addresses = additionalIPv6Addresses
)

dstIPv6Addresses := append(lbIPv6Addresses, backendIPv6Addresses...)
dstIPv6Addresses = append(dstIPv6Addresses, additionalIPv6Addresses...)
if disableFloatingIP {
// use the backend node IPs
dstIPv4Addresses = append(dstIPv4Addresses, backendIPv4Addresses...)
dstIPv6Addresses = append(dstIPv6Addresses, backendIPv6Addresses...)
} else {
// use the LoadBalancer IPs
dstIPv4Addresses = append(dstIPv4Addresses, lbIPv4Addresses...)
dstIPv6Addresses = append(dstIPv6Addresses, lbIPv6Addresses...)
}

retainPortRanges, err := az.listSharedIPPortMapping(service, publicIPs)
{
retainPortRanges, err := az.listSharedIPPortMapping(ctx, service, append(dstIPv4Addresses, dstIPv6Addresses...))
if err != nil {
logger.Error(err, "Failed to list retain port ranges")
return nil, err
Expand All @@ -3021,21 +2919,6 @@ func (az *Cloud) reconcileSecurityGroup(
}

if wantLb {
var (
dstIPv4Addresses = additionalIPv4Addresses
dstIPv6Addresses = additionalIPv6Addresses
)

if disableFloatingIP {
// use the backend node IPs
dstIPv4Addresses = append(dstIPv4Addresses, backendIPv4Addresses...)
dstIPv6Addresses = append(dstIPv6Addresses, backendIPv6Addresses...)
} else {
// use the LoadBalancer IPs
dstIPv4Addresses = append(dstIPv4Addresses, lbIPv4Addresses...)
dstIPv6Addresses = append(dstIPv6Addresses, lbIPv6Addresses...)
}

err := accessControl.PatchSecurityGroup(dstIPv4Addresses, dstIPv6Addresses)
if err != nil {
logger.Error(err, "Failed to patch security group")
Expand Down
109 changes: 109 additions & 0 deletions pkg/provider/azure_loadbalancer_accesscontrol.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package provider

import (
"context"
"fmt"
"net/netip"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/klog/v2"

"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network"

"sigs.k8s.io/cloud-provider-azure/pkg/consts"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil"
)

func filterServicesByIngressIPs(services []*v1.Service, ips []netip.Addr) []*v1.Service {
targetIPs := fnutil.Map(func(ip netip.Addr) string { return ip.String() }, ips)

return fnutil.Filter(func(svc *v1.Service) bool {

ingressIPs := fnutil.Map(func(ing v1.LoadBalancerIngress) string { return ing.IP }, svc.Status.LoadBalancer.Ingress)

ingressIPs = fnutil.Filter(func(ip string) bool { return ip != "" }, ingressIPs)

return len(fnutil.Intersection(ingressIPs, targetIPs)) > 0
}, services)
}

func filterServicesByDisableFloatingIP(services []*v1.Service) []*v1.Service {
return fnutil.Filter(func(svc *v1.Service) bool {
return consts.IsK8sServiceDisableLoadBalancerFloatingIP(svc)
}, services)
}

// listSharedIPPortMapping lists the shared IP port mapping for the service excluding the service itself.
// There are scenarios where multiple services share the same public IP,
// and in order to clean up the security rules, we need to know the port mapping of the shared IP.
func (az *Cloud) listSharedIPPortMapping(
ctx context.Context,
svc *v1.Service,
ingressIPs []netip.Addr,
) (map[network.SecurityRuleProtocol][]int32, error) {
var (
logger = klog.FromContext(ctx).WithName("listSharedIPPortMapping")
rv = make(map[network.SecurityRuleProtocol][]int32)
)

var services []*v1.Service
{
var err error
logger.Info("Listing all services")
services, err = az.serviceLister.List(labels.Everything())
if err != nil {
logger.Error(err, "Failed to list all services")
return nil, fmt.Errorf("list all services: %w", err)
}
logger.Info("Listed all services", "num-all-services", len(services))

// Filter services by ingress IPs or backend node pool IPs (when disable floating IP)
if consts.IsK8sServiceDisableLoadBalancerFloatingIP(svc) {
logger.Info("Filter service by disableFloatingIP")
services = filterServicesByDisableFloatingIP(services)
} else {
logger.Info("Filter service by external IPs")
services = filterServicesByIngressIPs(services, ingressIPs)
}
}
logger.Info("Filtered services", "num-filtered-services", len(services))

for _, s := range services {
logger.V(4).Info("iterating service", "service", s.Name, "namespace", s.Namespace)
if svc.Namespace == s.Namespace && svc.Name == s.Name {
// skip the service itself
continue
}

portsByProtocol, err := loadbalancer.SecurityRuleDestinationPortsByProtocol(s)
if err != nil {
return nil, fmt.Errorf("fetch security rule dst ports for %s: %w", s.Name, err)
}

for protocol, ports := range portsByProtocol {
rv[protocol] = append(rv[protocol], ports...)
}
}

logger.V(4).Info("retain port mapping", "port-mapping", rv)

return rv, nil
}
Loading

0 comments on commit 71bff3b

Please sign in to comment.