diff --git a/pkg/firewalls/firewalls_l4.go b/pkg/firewalls/firewalls_l4.go index d79bfcb574..8844f4d0c7 100644 --- a/pkg/firewalls/firewalls_l4.go +++ b/pkg/firewalls/firewalls_l4.go @@ -122,6 +122,10 @@ func firewallRuleEqual(a, b *compute.Firewall, skipDescription bool) bool { } } + if !utils.EqualStringSets(a.DestinationRanges, b.DestinationRanges) { + return false + } + if !utils.EqualStringSets(a.SourceRanges, b.SourceRanges) { return false } diff --git a/pkg/loadbalancers/l4_test.go b/pkg/loadbalancers/l4_test.go index eb4abc3732..10604783ee 100644 --- a/pkg/loadbalancers/l4_test.go +++ b/pkg/loadbalancers/l4_test.go @@ -1069,7 +1069,7 @@ func TestEnsureInternalFirewallPortRanges(t *testing.T) { Protocol: string(v1.ProtocolTCP), IP: "1.2.3.4", } - firewalls.EnsureL4FirewallRule(l.cloud, utils.ServiceKeyFunc(svc.Namespace, svc.Name), &fwrParams /*sharedRule = */, false) + err = firewalls.EnsureL4FirewallRule(l.cloud, utils.ServiceKeyFunc(svc.Namespace, svc.Name), &fwrParams /*sharedRule = */, false) if err != nil { t.Errorf("Unexpected error %v when ensuring firewall rule %s for svc %+v", err, fwName, svc) } diff --git a/pkg/loadbalancers/l4netlb_test.go b/pkg/loadbalancers/l4netlb_test.go index 4a7bcb0157..646383e61a 100644 --- a/pkg/loadbalancers/l4netlb_test.go +++ b/pkg/loadbalancers/l4netlb_test.go @@ -17,10 +17,14 @@ package loadbalancers import ( "fmt" - "k8s.io/ingress-gce/pkg/healthchecks" + "reflect" "strings" "testing" + "k8s.io/ingress-gce/pkg/firewalls" + "k8s.io/ingress-gce/pkg/flags" + "k8s.io/ingress-gce/pkg/healthchecks" + "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud" "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/meta" "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/mock" @@ -388,6 +392,59 @@ func TestMetricsForStandardNetworkTier(t *testing.T) { } } +func TestEnsureNetLBFirewallDestinations(t *testing.T) { + nodeNames := []string{"test-node-1"} + vals := gce.DefaultTestClusterValues() + fakeGCE := getFakeGCECloud(vals) + + svc := test.NewL4NetLBRBSService(8080) + namer := namer_util.NewL4Namer(kubeSystemUID, nil) + l4netlb := NewL4NetLB(svc, fakeGCE, meta.Regional, namer, record.NewFakeRecorder(100)) + l4netlb.l4HealthChecks = healthchecks.FakeL4(fakeGCE, &test.FakeRecorderSource{}) + + if _, err := test.CreateAndInsertNodes(l4netlb.cloud, nodeNames, vals.ZoneName); err != nil { + t.Errorf("Unexpected error when adding nodes %v", err) + } + + flags.F.EnablePinhole = true + fwName, _ := l4netlb.namer.L4Backend(l4netlb.Service.Namespace, l4netlb.Service.Name) + + fwrParams := firewalls.FirewallParams{ + Name: fwName, + SourceRanges: []string{"10.0.0.0/20"}, + DestinationRanges: []string{"20.0.0.0/20"}, + NodeNames: nodeNames, + Protocol: string(v1.ProtocolTCP), + IP: "1.2.3.4", + } + + err := firewalls.EnsureL4FirewallRule(l4netlb.cloud, utils.ServiceKeyFunc(svc.Namespace, svc.Name), &fwrParams /*sharedRule = */, false) + if err != nil { + t.Errorf("Unexpected error %v when ensuring firewall rule %s for svc %+v", err, fwName, svc) + } + existingFirewall, err := l4netlb.cloud.GetFirewall(fwName) + if err != nil || existingFirewall == nil || len(existingFirewall.Allowed) == 0 { + t.Errorf("Unexpected error %v when looking up firewall %s, Got firewall %+v", err, fwName, existingFirewall) + } + oldDestinationRanges := existingFirewall.DestinationRanges + + fwrParams.DestinationRanges = []string{"30.0.0.0/20"} + err = firewalls.EnsureL4FirewallRule(l4netlb.cloud, utils.ServiceKeyFunc(svc.Namespace, svc.Name), &fwrParams /*sharedRule = */, false) + if err != nil { + t.Errorf("Unexpected error %v when ensuring firewall rule %s for svc %+v", err, fwName, svc) + } + + updatedFirewall, err := l4netlb.cloud.GetFirewall(fwName) + if err != nil || updatedFirewall == nil || len(updatedFirewall.Allowed) == 0 { + t.Errorf("Unexpected error %v when looking up firewall %s, Got firewall %+v", err, fwName, updatedFirewall) + } + + if reflect.DeepEqual(oldDestinationRanges, updatedFirewall.DestinationRanges) { + t.Errorf("DestinationRanges is not udpated. oldDestinationRanges:%v, updatedFirewall.DestinationRanges:%v", oldDestinationRanges, updatedFirewall.DestinationRanges) + } + +} + func createUserStaticIPInStandardTier(fakeGCE *gce.Cloud, region string) { fakeGCE.Compute().(*cloud.MockGCE).MockAddresses.InsertHook = mock.InsertAddressHook fakeGCE.Compute().(*cloud.MockGCE).MockAlphaAddresses.X = mock.AddressAttributes{}