Skip to content

Commit

Permalink
migrate nsg client to track2 one
Browse files Browse the repository at this point in the history
Signed-off-by: Fan Shang Xiang <shafan@microsoft.com>
  • Loading branch information
MartinForReal committed Sep 29, 2024
1 parent ffcf976 commit c2d6203
Show file tree
Hide file tree
Showing 23 changed files with 1,606 additions and 1,461 deletions.
20 changes: 10 additions & 10 deletions internal/testutil/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@ import (
"sort"
"testing"

"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
"github.com/stretchr/testify/assert"
)

// ExpectHasSecurityRules asserts the security group whether it has the given rules.
func ExpectHasSecurityRules(t *testing.T, sg *network.SecurityGroup, expected []network.SecurityRule, msgAndArgs ...any) {
func ExpectHasSecurityRules(t *testing.T, sg *armnetwork.SecurityGroup, expected []*armnetwork.SecurityRule, msgAndArgs ...any) {
t.Helper()

expectedRuleIndex := make(map[string]network.SecurityRule)
expectedRuleIndex := make(map[string]*armnetwork.SecurityRule)
for _, rule := range expected {
expectedRuleIndex[*rule.Name] = rule
}

for _, actual := range *sg.SecurityRules {
for _, actual := range sg.Properties.SecurityRules {
expected, found := expectedRuleIndex[*actual.Name]
if !found {
continue
Expand All @@ -48,21 +48,21 @@ func ExpectHasSecurityRules(t *testing.T, sg *network.SecurityGroup, expected []
}

// ExpectExactSecurityRules asserts the security group whether it has the exact same rules.
func ExpectExactSecurityRules(t *testing.T, sg *network.SecurityGroup, expected []network.SecurityRule, msgAndArgs ...any) {
func ExpectExactSecurityRules(t *testing.T, sg *armnetwork.SecurityGroup, expected []*armnetwork.SecurityRule, msgAndArgs ...any) {
t.Helper()

assert.NotNil(t, sg)
assert.NotNil(t, sg.SecurityGroupPropertiesFormat)
assert.NotNil(t, sg.SecurityGroupPropertiesFormat.SecurityRules)
assert.NotNil(t, sg.Properties)
assert.NotNil(t, sg.Properties.SecurityRules)

actual := *sg.SecurityRules
actual := sg.Properties.SecurityRules

// order insensitive
sort.Slice(actual, func(i, j int) bool {
return *actual[i].Priority < *actual[j].Priority
return *actual[i].Properties.Priority < *actual[j].Properties.Priority
})
sort.Slice(expected, func(i, j int) bool {
return *expected[i].Priority < *expected[j].Priority
return *expected[i].Properties.Priority < *expected[j].Properties.Priority
})

ExpectEqualInJSON(t, expected, actual, msgAndArgs...)
Expand Down
120 changes: 60 additions & 60 deletions internal/testutil/fixture/azure_securitygroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import (
"sort"
"strconv"

"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
"k8s.io/utils/ptr"

"sigs.k8s.io/cloud-provider-azure/internal/testutil"
Expand All @@ -34,58 +34,58 @@ import (

// NoiseSecurityRules returns 3 non cloud-provider-specific security rules.
// Use NNoiseSecurityRules if you need more.
func (f *AzureFixture) NoiseSecurityRules() []network.SecurityRule {
func (f *AzureFixture) NoiseSecurityRules() []*armnetwork.SecurityRule {
return f.NNoiseSecurityRules(3)
}

// NNoiseSecurityRules returns N non cloud-provider-specific security rules.
// It's not random, but it's good enough for testing.
func (f *AzureFixture) NNoiseSecurityRules(nRules int) []network.SecurityRule {
func (f *AzureFixture) NNoiseSecurityRules(nRules int) []*armnetwork.SecurityRule {
var (
rv = make([]network.SecurityRule, 0, nRules)
protocolByID = func(id int) network.SecurityRuleProtocol {
rv = make([]*armnetwork.SecurityRule, 0, nRules)
protocolByID = func(id int) *armnetwork.SecurityRuleProtocol {
switch id % 3 {
case 0:
return network.SecurityRuleProtocolTCP
return to.Ptr(armnetwork.SecurityRuleProtocolTCP)
case 1:
return network.SecurityRuleProtocolUDP
return to.Ptr(armnetwork.SecurityRuleProtocolUDP)
default:
return network.SecurityRuleProtocolAsterisk
return to.Ptr(armnetwork.SecurityRuleProtocolAsterisk)
}
}
)

initPriority := int32(100)
for i := 0; i < nRules; i++ {
rule := network.SecurityRule{
rule := &armnetwork.SecurityRule{
Name: ptr.To(fmt.Sprintf("test-security-rule_%d", i)),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Properties: &armnetwork.SecurityRulePropertiesFormat{
Priority: ptr.To(initPriority),
Protocol: protocolByID(i),
Direction: network.SecurityRuleDirectionInbound,
Access: network.SecurityRuleAccessAllow,
SourceAddressPrefixes: ptr.To([]string{
Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound),
Access: to.Ptr(armnetwork.SecurityRuleAccessAllow),
SourceAddressPrefixes: to.SliceOfPtrs(
fmt.Sprintf("140.0.0.%d", i), // NOTE: keep the source IP / destination IP unique to LB ips.
fmt.Sprintf("130.0.50.%d", i),
}),
),
SourcePortRange: ptr.To("*"),
DestinationPortRanges: ptr.To([]string{
DestinationPortRanges: to.SliceOfPtrs(
fmt.Sprintf("4000%d", i),
fmt.Sprintf("5000%d", i),
}),
),
},
}

switch i % 3 {
case 0:
rule.DestinationAddressPrefixes = ptr.To([]string{
rule.Properties.DestinationAddressPrefixes = to.SliceOfPtrs(
fmt.Sprintf("222.111.0.%d", i),
fmt.Sprintf("200.0.50.%d", i),
})
)
case 1:
rule.DestinationAddressPrefix = ptr.To(fmt.Sprintf("222.111.0.%d", i))
rule.Properties.DestinationAddressPrefix = ptr.To(fmt.Sprintf("222.111.0.%d", i))
case 2:
rule.DestinationApplicationSecurityGroups = &[]network.ApplicationSecurityGroup{
rule.Properties.DestinationApplicationSecurityGroups = []*armnetwork.ApplicationSecurityGroup{
{
ID: ptr.To(fmt.Sprintf("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/the-rg/providers/Microsoft.Network/applicationSecurityGroups/the-asg-%d", i)),
},
Expand All @@ -105,17 +105,17 @@ func (f *AzureFixture) NNoiseSecurityRules(nRules int) []network.SecurityRule {

func (f *AzureFixture) SecurityGroup() *AzureSecurityGroupFixture {
return &AzureSecurityGroupFixture{
sg: &network.SecurityGroup{
sg: &armnetwork.SecurityGroup{
Name: ptr.To("nsg"),
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{
SecurityRules: &[]network.SecurityRule{},
Properties: &armnetwork.SecurityGroupPropertiesFormat{
SecurityRules: []*armnetwork.SecurityRule{},
},
},
}
}

func (f *AzureFixture) AllowSecurityRule(
protocol network.SecurityRuleProtocol,
protocol armnetwork.SecurityRuleProtocol,
ipFamily iputil.Family,
srcPrefixes []string,
dstPorts []int32,
Expand All @@ -124,36 +124,36 @@ func (f *AzureFixture) AllowSecurityRule(
sort.Strings(dstPortRanges)

rv := &AzureAllowSecurityRuleFixture{
rule: &network.SecurityRule{
rule: &armnetwork.SecurityRule{
Name: ptr.To(securitygroup.GenerateAllowSecurityRuleName(protocol, ipFamily, srcPrefixes, dstPorts)),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Protocol: protocol,
Access: network.SecurityRuleAccessAllow,
Direction: network.SecurityRuleDirectionInbound,
Properties: &armnetwork.SecurityRulePropertiesFormat{
Protocol: to.Ptr(protocol),
Access: to.Ptr(armnetwork.SecurityRuleAccessAllow),
Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound),
SourcePortRange: ptr.To("*"),
DestinationPortRanges: ptr.To(dstPortRanges),
DestinationPortRanges: to.SliceOfPtrs(dstPortRanges...),
Priority: ptr.To(int32(consts.LoadBalancerMinimumPriority)),
},
},
}

if len(srcPrefixes) == 1 {
rv.rule.SourceAddressPrefix = ptr.To(srcPrefixes[0])
rv.rule.Properties.SourceAddressPrefix = ptr.To(srcPrefixes[0])
} else {
rv.rule.SourceAddressPrefixes = ptr.To(srcPrefixes)
rv.rule.Properties.SourceAddressPrefixes = to.SliceOfPtrs(srcPrefixes...)
}

return rv
}

func (f *AzureFixture) DenyAllSecurityRule(ipFamily iputil.Family) *AzureDenyAllSecurityRuleFixture {
return &AzureDenyAllSecurityRuleFixture{
rule: &network.SecurityRule{
rule: &armnetwork.SecurityRule{
Name: ptr.To(securitygroup.GenerateDenyAllSecurityRuleName(ipFamily)),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Protocol: network.SecurityRuleProtocolAsterisk,
Access: network.SecurityRuleAccessDeny,
Direction: network.SecurityRuleDirectionInbound,
Properties: &armnetwork.SecurityRulePropertiesFormat{
Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk),
Access: to.Ptr(armnetwork.SecurityRuleAccessDeny),
Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound),
SourcePortRange: ptr.To("*"),
SourceAddressPrefix: ptr.To("*"),
DestinationPortRange: ptr.To("*"),
Expand All @@ -165,69 +165,69 @@ func (f *AzureFixture) DenyAllSecurityRule(ipFamily iputil.Family) *AzureDenyAll

// AzureSecurityGroupFixture is a fixture for an Azure security group.
type AzureSecurityGroupFixture struct {
sg *network.SecurityGroup
sg *armnetwork.SecurityGroup
}

func (f *AzureSecurityGroupFixture) WithRules(rules []network.SecurityRule) *AzureSecurityGroupFixture {
func (f *AzureSecurityGroupFixture) WithRules(rules []*armnetwork.SecurityRule) *AzureSecurityGroupFixture {
if rules == nil {
rules = []network.SecurityRule{}
rules = []*armnetwork.SecurityRule{}
}
clonedRules := testutil.CloneInJSON(rules) // keep the original one immutable
f.sg.SecurityRules = &clonedRules
f.sg.Properties.SecurityRules = clonedRules
return f
}

func (f *AzureSecurityGroupFixture) Build() network.SecurityGroup {
return *f.sg
func (f *AzureSecurityGroupFixture) Build() *armnetwork.SecurityGroup {
return f.sg
}

// AzureAllowSecurityRuleFixture is a fixture for an allow security rule.
type AzureAllowSecurityRuleFixture struct {
rule *network.SecurityRule
rule *armnetwork.SecurityRule
}

func (f *AzureAllowSecurityRuleFixture) WithPriority(p int32) *AzureAllowSecurityRuleFixture {
f.rule.Priority = ptr.To(p)
f.rule.Properties.Priority = ptr.To(p)
return f
}

func (f *AzureAllowSecurityRuleFixture) WithDestination(prefixes ...string) *AzureAllowSecurityRuleFixture {
if len(prefixes) == 1 {
f.rule.DestinationAddressPrefix = ptr.To(prefixes[0])
f.rule.DestinationAddressPrefixes = nil
f.rule.Properties.DestinationAddressPrefix = ptr.To(prefixes[0])
f.rule.Properties.DestinationAddressPrefixes = nil
} else {
f.rule.DestinationAddressPrefix = nil
f.rule.DestinationAddressPrefixes = ptr.To(securitygroup.NormalizeSecurityRuleAddressPrefixes(prefixes))
f.rule.Properties.DestinationAddressPrefix = nil
f.rule.Properties.DestinationAddressPrefixes = to.SliceOfPtrs(securitygroup.NormalizeSecurityRuleAddressPrefixes(prefixes)...)
}

return f
}

func (f *AzureAllowSecurityRuleFixture) Build() network.SecurityRule {
return *f.rule
func (f *AzureAllowSecurityRuleFixture) Build() *armnetwork.SecurityRule {
return f.rule
}

// AzureDenyAllSecurityRuleFixture is a fixture for a deny-all security rule.
type AzureDenyAllSecurityRuleFixture struct {
rule *network.SecurityRule
rule *armnetwork.SecurityRule
}

func (f *AzureDenyAllSecurityRuleFixture) WithPriority(p int32) *AzureDenyAllSecurityRuleFixture {
f.rule.Priority = ptr.To(p)
f.rule.Properties.Priority = ptr.To(p)
return f
}

func (f *AzureDenyAllSecurityRuleFixture) WithDestination(prefixes ...string) *AzureDenyAllSecurityRuleFixture {
if len(prefixes) == 1 {
f.rule.DestinationAddressPrefix = ptr.To(prefixes[0])
f.rule.DestinationAddressPrefixes = nil
f.rule.Properties.DestinationAddressPrefix = ptr.To(prefixes[0])
f.rule.Properties.DestinationAddressPrefixes = nil
} else {
f.rule.DestinationAddressPrefix = nil
f.rule.DestinationAddressPrefixes = ptr.To(securitygroup.NormalizeSecurityRuleAddressPrefixes(prefixes))
f.rule.Properties.DestinationAddressPrefix = nil
f.rule.Properties.DestinationAddressPrefixes = to.SliceOfPtrs(securitygroup.NormalizeSecurityRuleAddressPrefixes(prefixes)...)
}
return f
}

func (f *AzureDenyAllSecurityRuleFixture) Build() network.SecurityRule {
return *f.rule
func (f *AzureDenyAllSecurityRuleFixture) Build() *armnetwork.SecurityRule {
return f.rule
}
15 changes: 12 additions & 3 deletions pkg/provider/azure_fakes.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ import (

"sigs.k8s.io/cloud-provider-azure/pkg/azclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/privatezoneclient/mock_privatezoneclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/securitygroupclient/mock_securitygroupclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualnetworklinkclient/mock_virtualnetworklinkclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/diskclient/mockdiskclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/privatelinkserviceclient/mockprivatelinkserviceclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/routeclient/mockrouteclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/routetableclient/mockroutetableclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/securitygroupclient/mocksecuritygroupclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/snapshotclient/mocksnapshotclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient/mocksubnetclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient"
Expand Down Expand Up @@ -116,13 +118,20 @@ func GetTestCloud(ctrl *gomock.Controller) (az *Cloud) {
az.PublicIPAddressesClient = mockpublicipclient.NewMockInterface(ctrl)
az.RoutesClient = mockrouteclient.NewMockInterface(ctrl)
az.RouteTablesClient = mockroutetableclient.NewMockInterface(ctrl)
az.SecurityGroupsClient = mocksecuritygroupclient.NewMockInterface(ctrl)
az.SubnetsClient = mocksubnetclient.NewMockInterface(ctrl)
az.VirtualMachineScaleSetsClient = mockvmssclient.NewMockInterface(ctrl)
az.VirtualMachineScaleSetVMsClient = mockvmssvmclient.NewMockInterface(ctrl)
az.VirtualMachinesClient = mockvmclient.NewMockInterface(ctrl)
az.PrivateLinkServiceClient = mockprivatelinkserviceclient.NewMockInterface(ctrl)
az.ComputeClientFactory = mock_azclient.NewMockClientFactory(ctrl)
clientFactory := mock_azclient.NewMockClientFactory(ctrl)
az.ComputeClientFactory = clientFactory
az.NetworkClientFactory = clientFactory
securtyGrouptrack2Client := mock_securitygroupclient.NewMockInterface(ctrl)
clientFactory.EXPECT().GetSecurityGroupClient().Return(securtyGrouptrack2Client).AnyTimes()
mockPrivateDNSClient := mock_privatezoneclient.NewMockInterface(ctrl)
clientFactory.EXPECT().GetPrivateZoneClient().Return(mockPrivateDNSClient).AnyTimes()
virtualNetworkLinkClient := mock_virtualnetworklinkclient.NewMockInterface(ctrl)
clientFactory.EXPECT().GetVirtualNetworkLinkClient().Return(virtualNetworkLinkClient).AnyTimes()
az.AuthProvider = &azclient.AuthProvider{
ComputeCredential: mock_azclient.NewMockTokenCredential(ctrl),
}
Expand Down
Loading

0 comments on commit c2d6203

Please sign in to comment.