Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track2 sdk:migrate nsg client to track2 one #7155

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions internal/testutil/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@ import (
"sort"
"testing"

"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
"github.com/stretchr/testify/assert"
)

// ExpectHasSecurityRules asserts the security group whether it has the given rules.
func ExpectHasSecurityRules(t *testing.T, sg *network.SecurityGroup, expected []network.SecurityRule, msgAndArgs ...any) {
func ExpectHasSecurityRules(t *testing.T, sg *armnetwork.SecurityGroup, expected []*armnetwork.SecurityRule, msgAndArgs ...any) {
t.Helper()

expectedRuleIndex := make(map[string]network.SecurityRule)
expectedRuleIndex := make(map[string]*armnetwork.SecurityRule)
for _, rule := range expected {
expectedRuleIndex[*rule.Name] = rule
}

for _, actual := range *sg.SecurityRules {
for _, actual := range sg.Properties.SecurityRules {
expected, found := expectedRuleIndex[*actual.Name]
if !found {
continue
Expand All @@ -48,21 +48,21 @@ func ExpectHasSecurityRules(t *testing.T, sg *network.SecurityGroup, expected []
}

// ExpectExactSecurityRules asserts the security group whether it has the exact same rules.
func ExpectExactSecurityRules(t *testing.T, sg *network.SecurityGroup, expected []network.SecurityRule, msgAndArgs ...any) {
func ExpectExactSecurityRules(t *testing.T, sg *armnetwork.SecurityGroup, expected []*armnetwork.SecurityRule, msgAndArgs ...any) {
t.Helper()

assert.NotNil(t, sg)
assert.NotNil(t, sg.SecurityGroupPropertiesFormat)
assert.NotNil(t, sg.SecurityGroupPropertiesFormat.SecurityRules)
assert.NotNil(t, sg.Properties)
assert.NotNil(t, sg.Properties.SecurityRules)

actual := *sg.SecurityRules
actual := sg.Properties.SecurityRules

// order insensitive
sort.Slice(actual, func(i, j int) bool {
return *actual[i].Priority < *actual[j].Priority
return *actual[i].Properties.Priority < *actual[j].Properties.Priority
})
sort.Slice(expected, func(i, j int) bool {
return *expected[i].Priority < *expected[j].Priority
return *expected[i].Properties.Priority < *expected[j].Properties.Priority
})

ExpectEqualInJSON(t, expected, actual, msgAndArgs...)
Expand Down
120 changes: 60 additions & 60 deletions internal/testutil/fixture/azure_securitygroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import (
"sort"
"strconv"

"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
"k8s.io/utils/ptr"

"sigs.k8s.io/cloud-provider-azure/internal/testutil"
Expand All @@ -34,58 +34,58 @@ import (

// NoiseSecurityRules returns 3 non cloud-provider-specific security rules.
// Use NNoiseSecurityRules if you need more.
func (f *AzureFixture) NoiseSecurityRules() []network.SecurityRule {
func (f *AzureFixture) NoiseSecurityRules() []*armnetwork.SecurityRule {
return f.NNoiseSecurityRules(3)
}

// NNoiseSecurityRules returns N non cloud-provider-specific security rules.
// It's not random, but it's good enough for testing.
func (f *AzureFixture) NNoiseSecurityRules(nRules int) []network.SecurityRule {
func (f *AzureFixture) NNoiseSecurityRules(nRules int) []*armnetwork.SecurityRule {
var (
rv = make([]network.SecurityRule, 0, nRules)
protocolByID = func(id int) network.SecurityRuleProtocol {
rv = make([]*armnetwork.SecurityRule, 0, nRules)
protocolByID = func(id int) *armnetwork.SecurityRuleProtocol {
switch id % 3 {
case 0:
return network.SecurityRuleProtocolTCP
return to.Ptr(armnetwork.SecurityRuleProtocolTCP)
case 1:
return network.SecurityRuleProtocolUDP
return to.Ptr(armnetwork.SecurityRuleProtocolUDP)
default:
return network.SecurityRuleProtocolAsterisk
return to.Ptr(armnetwork.SecurityRuleProtocolAsterisk)
}
}
)

initPriority := int32(100)
for i := 0; i < nRules; i++ {
rule := network.SecurityRule{
rule := &armnetwork.SecurityRule{
Name: ptr.To(fmt.Sprintf("test-security-rule_%d", i)),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Properties: &armnetwork.SecurityRulePropertiesFormat{
Priority: ptr.To(initPriority),
Protocol: protocolByID(i),
Direction: network.SecurityRuleDirectionInbound,
Access: network.SecurityRuleAccessAllow,
SourceAddressPrefixes: ptr.To([]string{
Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound),
Access: to.Ptr(armnetwork.SecurityRuleAccessAllow),
SourceAddressPrefixes: to.SliceOfPtrs(
fmt.Sprintf("140.0.0.%d", i), // NOTE: keep the source IP / destination IP unique to LB ips.
fmt.Sprintf("130.0.50.%d", i),
}),
),
SourcePortRange: ptr.To("*"),
DestinationPortRanges: ptr.To([]string{
DestinationPortRanges: to.SliceOfPtrs(
fmt.Sprintf("4000%d", i),
fmt.Sprintf("5000%d", i),
}),
),
},
}

switch i % 3 {
case 0:
rule.DestinationAddressPrefixes = ptr.To([]string{
rule.Properties.DestinationAddressPrefixes = to.SliceOfPtrs(
fmt.Sprintf("222.111.0.%d", i),
fmt.Sprintf("200.0.50.%d", i),
})
)
case 1:
rule.DestinationAddressPrefix = ptr.To(fmt.Sprintf("222.111.0.%d", i))
rule.Properties.DestinationAddressPrefix = ptr.To(fmt.Sprintf("222.111.0.%d", i))
case 2:
rule.DestinationApplicationSecurityGroups = &[]network.ApplicationSecurityGroup{
rule.Properties.DestinationApplicationSecurityGroups = []*armnetwork.ApplicationSecurityGroup{
{
ID: ptr.To(fmt.Sprintf("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/the-rg/providers/Microsoft.Network/applicationSecurityGroups/the-asg-%d", i)),
},
Expand All @@ -105,17 +105,17 @@ func (f *AzureFixture) NNoiseSecurityRules(nRules int) []network.SecurityRule {

func (f *AzureFixture) SecurityGroup() *AzureSecurityGroupFixture {
return &AzureSecurityGroupFixture{
sg: &network.SecurityGroup{
sg: &armnetwork.SecurityGroup{
Name: ptr.To("nsg"),
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{
SecurityRules: &[]network.SecurityRule{},
Properties: &armnetwork.SecurityGroupPropertiesFormat{
SecurityRules: []*armnetwork.SecurityRule{},
},
},
}
}

func (f *AzureFixture) AllowSecurityRule(
protocol network.SecurityRuleProtocol,
protocol armnetwork.SecurityRuleProtocol,
ipFamily iputil.Family,
srcPrefixes []string,
dstPorts []int32,
Expand All @@ -124,36 +124,36 @@ func (f *AzureFixture) AllowSecurityRule(
sort.Strings(dstPortRanges)

rv := &AzureAllowSecurityRuleFixture{
rule: &network.SecurityRule{
rule: &armnetwork.SecurityRule{
Name: ptr.To(securitygroup.GenerateAllowSecurityRuleName(protocol, ipFamily, srcPrefixes, dstPorts)),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Protocol: protocol,
Access: network.SecurityRuleAccessAllow,
Direction: network.SecurityRuleDirectionInbound,
Properties: &armnetwork.SecurityRulePropertiesFormat{
Protocol: to.Ptr(protocol),
Access: to.Ptr(armnetwork.SecurityRuleAccessAllow),
Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound),
SourcePortRange: ptr.To("*"),
DestinationPortRanges: ptr.To(dstPortRanges),
DestinationPortRanges: to.SliceOfPtrs(dstPortRanges...),
Priority: ptr.To(int32(consts.LoadBalancerMinimumPriority)),
},
},
}

if len(srcPrefixes) == 1 {
rv.rule.SourceAddressPrefix = ptr.To(srcPrefixes[0])
rv.rule.Properties.SourceAddressPrefix = ptr.To(srcPrefixes[0])
} else {
rv.rule.SourceAddressPrefixes = ptr.To(srcPrefixes)
rv.rule.Properties.SourceAddressPrefixes = to.SliceOfPtrs(srcPrefixes...)
}

return rv
}

func (f *AzureFixture) DenyAllSecurityRule(ipFamily iputil.Family) *AzureDenyAllSecurityRuleFixture {
return &AzureDenyAllSecurityRuleFixture{
rule: &network.SecurityRule{
rule: &armnetwork.SecurityRule{
Name: ptr.To(securitygroup.GenerateDenyAllSecurityRuleName(ipFamily)),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Protocol: network.SecurityRuleProtocolAsterisk,
Access: network.SecurityRuleAccessDeny,
Direction: network.SecurityRuleDirectionInbound,
Properties: &armnetwork.SecurityRulePropertiesFormat{
Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk),
Access: to.Ptr(armnetwork.SecurityRuleAccessDeny),
Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound),
SourcePortRange: ptr.To("*"),
SourceAddressPrefix: ptr.To("*"),
DestinationPortRange: ptr.To("*"),
Expand All @@ -165,69 +165,69 @@ func (f *AzureFixture) DenyAllSecurityRule(ipFamily iputil.Family) *AzureDenyAll

// AzureSecurityGroupFixture is a fixture for an Azure security group.
type AzureSecurityGroupFixture struct {
sg *network.SecurityGroup
sg *armnetwork.SecurityGroup
}

func (f *AzureSecurityGroupFixture) WithRules(rules []network.SecurityRule) *AzureSecurityGroupFixture {
func (f *AzureSecurityGroupFixture) WithRules(rules []*armnetwork.SecurityRule) *AzureSecurityGroupFixture {
if rules == nil {
rules = []network.SecurityRule{}
rules = []*armnetwork.SecurityRule{}
}
clonedRules := testutil.CloneInJSON(rules) // keep the original one immutable
f.sg.SecurityRules = &clonedRules
f.sg.Properties.SecurityRules = clonedRules
return f
}

func (f *AzureSecurityGroupFixture) Build() network.SecurityGroup {
return *f.sg
func (f *AzureSecurityGroupFixture) Build() *armnetwork.SecurityGroup {
return f.sg
}

// AzureAllowSecurityRuleFixture is a fixture for an allow security rule.
type AzureAllowSecurityRuleFixture struct {
rule *network.SecurityRule
rule *armnetwork.SecurityRule
}

func (f *AzureAllowSecurityRuleFixture) WithPriority(p int32) *AzureAllowSecurityRuleFixture {
f.rule.Priority = ptr.To(p)
f.rule.Properties.Priority = ptr.To(p)
return f
}

func (f *AzureAllowSecurityRuleFixture) WithDestination(prefixes ...string) *AzureAllowSecurityRuleFixture {
if len(prefixes) == 1 {
f.rule.DestinationAddressPrefix = ptr.To(prefixes[0])
f.rule.DestinationAddressPrefixes = nil
f.rule.Properties.DestinationAddressPrefix = ptr.To(prefixes[0])
f.rule.Properties.DestinationAddressPrefixes = nil
} else {
f.rule.DestinationAddressPrefix = nil
f.rule.DestinationAddressPrefixes = ptr.To(securitygroup.NormalizeSecurityRuleAddressPrefixes(prefixes))
f.rule.Properties.DestinationAddressPrefix = nil
f.rule.Properties.DestinationAddressPrefixes = to.SliceOfPtrs(securitygroup.NormalizeSecurityRuleAddressPrefixes(prefixes)...)
}

return f
}

func (f *AzureAllowSecurityRuleFixture) Build() network.SecurityRule {
return *f.rule
func (f *AzureAllowSecurityRuleFixture) Build() *armnetwork.SecurityRule {
return f.rule
}

// AzureDenyAllSecurityRuleFixture is a fixture for a deny-all security rule.
type AzureDenyAllSecurityRuleFixture struct {
rule *network.SecurityRule
rule *armnetwork.SecurityRule
}

func (f *AzureDenyAllSecurityRuleFixture) WithPriority(p int32) *AzureDenyAllSecurityRuleFixture {
f.rule.Priority = ptr.To(p)
f.rule.Properties.Priority = ptr.To(p)
return f
}

func (f *AzureDenyAllSecurityRuleFixture) WithDestination(prefixes ...string) *AzureDenyAllSecurityRuleFixture {
if len(prefixes) == 1 {
f.rule.DestinationAddressPrefix = ptr.To(prefixes[0])
f.rule.DestinationAddressPrefixes = nil
f.rule.Properties.DestinationAddressPrefix = ptr.To(prefixes[0])
f.rule.Properties.DestinationAddressPrefixes = nil
} else {
f.rule.DestinationAddressPrefix = nil
f.rule.DestinationAddressPrefixes = ptr.To(securitygroup.NormalizeSecurityRuleAddressPrefixes(prefixes))
f.rule.Properties.DestinationAddressPrefix = nil
f.rule.Properties.DestinationAddressPrefixes = to.SliceOfPtrs(securitygroup.NormalizeSecurityRuleAddressPrefixes(prefixes)...)
}
return f
}

func (f *AzureDenyAllSecurityRuleFixture) Build() network.SecurityRule {
return *f.rule
func (f *AzureDenyAllSecurityRuleFixture) Build() *armnetwork.SecurityRule {
return f.rule
}
15 changes: 12 additions & 3 deletions pkg/provider/azure_fakes.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ import (

"sigs.k8s.io/cloud-provider-azure/pkg/azclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/privatezoneclient/mock_privatezoneclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/securitygroupclient/mock_securitygroupclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualnetworklinkclient/mock_virtualnetworklinkclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/diskclient/mockdiskclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/privatelinkserviceclient/mockprivatelinkserviceclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/routeclient/mockrouteclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/routetableclient/mockroutetableclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/securitygroupclient/mocksecuritygroupclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/snapshotclient/mocksnapshotclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient/mocksubnetclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient"
Expand Down Expand Up @@ -116,13 +118,20 @@ func GetTestCloud(ctrl *gomock.Controller) (az *Cloud) {
az.PublicIPAddressesClient = mockpublicipclient.NewMockInterface(ctrl)
az.RoutesClient = mockrouteclient.NewMockInterface(ctrl)
az.RouteTablesClient = mockroutetableclient.NewMockInterface(ctrl)
az.SecurityGroupsClient = mocksecuritygroupclient.NewMockInterface(ctrl)
az.SubnetsClient = mocksubnetclient.NewMockInterface(ctrl)
az.VirtualMachineScaleSetsClient = mockvmssclient.NewMockInterface(ctrl)
az.VirtualMachineScaleSetVMsClient = mockvmssvmclient.NewMockInterface(ctrl)
az.VirtualMachinesClient = mockvmclient.NewMockInterface(ctrl)
az.PrivateLinkServiceClient = mockprivatelinkserviceclient.NewMockInterface(ctrl)
az.ComputeClientFactory = mock_azclient.NewMockClientFactory(ctrl)
clientFactory := mock_azclient.NewMockClientFactory(ctrl)
az.ComputeClientFactory = clientFactory
az.NetworkClientFactory = clientFactory
securtyGrouptrack2Client := mock_securitygroupclient.NewMockInterface(ctrl)
clientFactory.EXPECT().GetSecurityGroupClient().Return(securtyGrouptrack2Client).AnyTimes()
mockPrivateDNSClient := mock_privatezoneclient.NewMockInterface(ctrl)
clientFactory.EXPECT().GetPrivateZoneClient().Return(mockPrivateDNSClient).AnyTimes()
virtualNetworkLinkClient := mock_virtualnetworklinkclient.NewMockInterface(ctrl)
clientFactory.EXPECT().GetVirtualNetworkLinkClient().Return(virtualNetworkLinkClient).AnyTimes()
az.AuthProvider = &azclient.AuthProvider{
ComputeCredential: mock_azclient.NewMockTokenCredential(ctrl),
}
Expand Down
Loading