Skip to content

Commit

Permalink
Merge pull request kubernetes#81500 from feiskyer/fix-81496
Browse files Browse the repository at this point in the history
Get location and subscriptionID from IMDS when useInstanceMetadata is true
  • Loading branch information
k8s-ci-robot authored Aug 17, 2019
2 parents 654df1d + bd85699 commit 667ea63
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ type NetworkData struct {

// IPAddress represents IP address information.
type IPAddress struct {
PrivateIP string `json:"privateIPAddress"`
PublicIP string `json:"publicIPAddress"`
PrivateIP string `json:"privateIpAddress"`
PublicIP string `json:"publicIpAddress"`
}

// Subnet represents subnet information.
Expand All @@ -62,6 +62,7 @@ type Subnet struct {

// ComputeMetadata represents compute information
type ComputeMetadata struct {
Environment string `json:"azEnvironment,omitempty"`
SKU string `json:"sku,omitempty"`
Name string `json:"name,omitempty"`
Zone string `json:"zone,omitempty"`
Expand All @@ -72,6 +73,7 @@ type ComputeMetadata struct {
UpdateDomain string `json:"platformUpdateDomain,omitempty"`
ResourceGroup string `json:"resourceGroupName,omitempty"`
VMScaleSetName string `json:"vmScaleSetName,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
}

// InstanceMetadata represents instance information.
Expand Down Expand Up @@ -111,7 +113,7 @@ func (ims *InstanceMetadataService) getInstanceMetadata(key string) (interface{}

q := req.URL.Query()
q.Add("format", "json")
q.Add("api-version", "2017-12-01")
q.Add("api-version", "2019-03-11")
req.URL.RawQuery = q.Encode()

client := &http.Client{}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,25 +281,26 @@ func (az *Cloud) InstanceID(ctx context.Context, name types.NodeName) (string, e
return "", fmt.Errorf("no credentials provided for Azure cloud provider")
}

// Get resource group name.
// Get resource group name and subscription ID.
resourceGroup := strings.ToLower(metadata.Compute.ResourceGroup)
subscriptionID := strings.ToLower(metadata.Compute.SubscriptionID)

// Compose instanceID based on nodeName for standard instance.
if az.VMType == vmTypeStandard {
return az.getStandardMachineID(resourceGroup, nodeName), nil
if metadata.Compute.VMScaleSetName == "" {
return az.getStandardMachineID(subscriptionID, resourceGroup, nodeName), nil
}

// Get scale set name and instanceID from vmName for vmss.
ssName, instanceID, err := extractVmssVMName(metadata.Compute.Name)
if err != nil {
if err == ErrorNotVmssInstance {
// Compose machineID for standard Node.
return az.getStandardMachineID(resourceGroup, nodeName), nil
return az.getStandardMachineID(subscriptionID, resourceGroup, nodeName), nil
}
return "", err
}
// Compose instanceID based on ssName and instanceID for vmss instance.
return az.getVmssMachineID(resourceGroup, ssName, instanceID), nil
return az.getVmssMachineID(subscriptionID, resourceGroup, ssName, instanceID), nil
}

return az.vmSet.GetInstanceIDByNodeName(nodeName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func setTestVirtualMachines(c *Cloud, vmList map[string]string, isDataDisksFull

func TestInstanceID(t *testing.T) {
cloud := getTestCloud()
cloud.Config.UseInstanceMetadata = true

testcases := []struct {
name string
Expand Down Expand Up @@ -120,7 +121,7 @@ func TestInstanceID(t *testing.T) {

mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, fmt.Sprintf(`{"compute":{"name":"%s"}}`, test.metadataName))
fmt.Fprintf(w, fmt.Sprintf(`{"compute":{"name":"%s","subscriptionId":"subscription","resourceGroupName":"rg"}}`, test.metadataName))
}))
go func() {
http.Serve(listener, mux)
Expand Down Expand Up @@ -214,7 +215,7 @@ func TestInstanceShutdownByProviderID(t *testing.T) {
for _, test := range testcases {
cloud := getTestCloud()
setTestVirtualMachines(cloud, test.vmList, false)
providerID := "azure://" + cloud.getStandardMachineID("rg", test.nodeName)
providerID := "azure://" + cloud.getStandardMachineID("subscription", "rg", test.nodeName)
hasShutdown, err := cloud.InstanceShutdownByProviderID(context.Background(), providerID)
if test.expectError {
if err == nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func (c *Cloud) GetAzureDiskLabels(diskURI string) (map[string]string, error) {
return nil, fmt.Errorf("failed to parse zone %v for AzureDisk %v: %v", zones, diskName, err)
}

zone := c.makeZone(zoneID)
zone := c.makeZone(c.Location, zoneID)
klog.V(4).Infof("Got zone %q for Azure disk %q", zone, diskName)
labels := map[string]string{
v1.LabelZoneRegion: c.Location,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ var nicResourceGroupRE = regexp.MustCompile(`.*/subscriptions/(?:.*)/resourceGro
var publicIPResourceGroupRE = regexp.MustCompile(`.*/subscriptions/(?:.*)/resourceGroups/(.+)/providers/Microsoft.Network/publicIPAddresses/(?:.*)`)

// getStandardMachineID returns the full identifier of a virtual machine.
func (az *Cloud) getStandardMachineID(resourceGroup, machineName string) string {
func (az *Cloud) getStandardMachineID(subscriptionID, resourceGroup, machineName string) string {
return fmt.Sprintf(
machineIDTemplate,
az.SubscriptionID,
subscriptionID,
strings.ToLower(resourceGroup),
machineName)
}
Expand Down Expand Up @@ -413,15 +413,15 @@ func (as *availabilitySet) GetZoneByNodeName(name string) (cloudprovider.Zone, e
return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone %q: %v", zones, err)
}

failureDomain = as.makeZone(zoneID)
failureDomain = as.makeZone(to.String(vm.Location), zoneID)
} else {
// Availability zone is not used for the node, falling back to fault domain.
failureDomain = strconv.Itoa(int(*vm.VirtualMachineProperties.InstanceView.PlatformFaultDomain))
}

zone := cloudprovider.Zone{
FailureDomain: failureDomain,
Region: *(vm.Location),
Region: to.String(vm.Location),
}
return zone, nil
}
Expand Down
66 changes: 0 additions & 66 deletions staging/src/k8s.io/legacy-cloud-providers/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import (
"context"
"fmt"
"math"
"net"
"net/http"
"strings"
"testing"

Expand Down Expand Up @@ -1726,70 +1724,6 @@ func validateEmptyConfig(t *testing.T, config string) {
}
}

func TestGetZone(t *testing.T) {
cloud := &Cloud{
Config: Config{
Location: "eastus",
UseInstanceMetadata: true,
},
}
testcases := []struct {
name string
zone string
faultDomain string
expected string
}{
{
name: "GetZone should get real zone if only node's zone is set",
zone: "1",
expected: "eastus-1",
},
{
name: "GetZone should get real zone if both node's zone and FD are set",
zone: "1",
faultDomain: "99",
expected: "eastus-1",
},
{
name: "GetZone should get faultDomain if node's zone isn't set",
faultDomain: "99",
expected: "99",
},
}

for _, test := range testcases {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}

mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, fmt.Sprintf(`{"compute":{"zone":"%s", "platformFaultDomain":"%s"}}`, test.zone, test.faultDomain))
}))
go func() {
http.Serve(listener, mux)
}()
defer listener.Close()

cloud.metadata, err = NewInstanceMetadataService("http://" + listener.Addr().String() + "/")
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}

zone, err := cloud.GetZone(context.Background())
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}
if zone.FailureDomain != test.expected {
t.Errorf("Test [%s] unexpected zone: %s, expected %q", test.name, zone.FailureDomain, test.expected)
}
if zone.Region != cloud.Location {
t.Errorf("Test [%s] unexpected region: %s, expected: %s", test.name, zone.Region, cloud.Location)
}
}
}

func TestGetNodeNameByProviderID(t *testing.T) {
az := getTestCloud()
providers := []struct {
Expand Down
8 changes: 4 additions & 4 deletions staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmss.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,15 @@ func (ss *scaleSet) GetZoneByNodeName(name string) (cloudprovider.Zone, error) {
return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone %q: %v", zones, err)
}

failureDomain = ss.makeZone(zoneID)
failureDomain = ss.makeZone(to.String(vm.Location), zoneID)
} else if vm.InstanceView != nil && vm.InstanceView.PlatformFaultDomain != nil {
// Availability zone is not used for the node, falling back to fault domain.
failureDomain = strconv.Itoa(int(*vm.InstanceView.PlatformFaultDomain))
}

return cloudprovider.Zone{
FailureDomain: failureDomain,
Region: *vm.Location,
Region: to.String(vm.Location),
}, nil
}

Expand Down Expand Up @@ -399,10 +399,10 @@ func (ss *scaleSet) getPrimaryInterfaceID(machine compute.VirtualMachineScaleSet
}

// getVmssMachineID returns the full identifier of a vmss virtual machine.
func (az *Cloud) getVmssMachineID(resourceGroup, scaleSetName, instanceID string) string {
func (az *Cloud) getVmssMachineID(subscriptionID, resourceGroup, scaleSetName, instanceID string) string {
return fmt.Sprintf(
vmssMachineIDTemplate,
az.SubscriptionID,
subscriptionID,
strings.ToLower(resourceGroup),
scaleSetName,
instanceID)
Expand Down
10 changes: 6 additions & 4 deletions staging/src/k8s.io/legacy-cloud-providers/azure/azure_zones.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (
)

// makeZone returns the zone value in format of <region>-<zone-id>.
func (az *Cloud) makeZone(zoneID int) string {
return fmt.Sprintf("%s-%d", strings.ToLower(az.Location), zoneID)
func (az *Cloud) makeZone(location string, zoneID int) string {
return fmt.Sprintf("%s-%d", strings.ToLower(location), zoneID)
}

// isAvailabilityZone returns true if the zone is in format of <region>-<zone-id>.
Expand All @@ -57,24 +57,26 @@ func (az *Cloud) GetZone(ctx context.Context) (cloudprovider.Zone, error) {
}

if metadata.Compute == nil {
az.metadata.imsCache.Delete(metadataCacheKey)
return cloudprovider.Zone{}, fmt.Errorf("failure of getting compute information from instance metadata")
}

zone := ""
location := metadata.Compute.Location
if metadata.Compute.Zone != "" {
zoneID, err := strconv.Atoi(metadata.Compute.Zone)
if err != nil {
return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone ID %q: %v", metadata.Compute.Zone, err)
}
zone = az.makeZone(zoneID)
zone = az.makeZone(location, zoneID)
} else {
klog.V(3).Infof("Availability zone is not enabled for the node, falling back to fault domain")
zone = metadata.Compute.FaultDomain
}

return cloudprovider.Zone{
FailureDomain: zone,
Region: az.Location,
Region: location,
}, nil
}
// if UseInstanceMetadata is false, get Zone name by calling ARM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ limitations under the License.
package azure

import (
"context"
"fmt"
"net"
"net/http"
"testing"
)

Expand Down Expand Up @@ -71,3 +75,67 @@ func TestGetZoneID(t *testing.T) {
}
}
}

func TestGetZone(t *testing.T) {
cloud := &Cloud{
Config: Config{
Location: "eastus",
UseInstanceMetadata: true,
},
}
testcases := []struct {
name string
zone string
faultDomain string
expected string
}{
{
name: "GetZone should get real zone if only node's zone is set",
zone: "1",
expected: "eastus-1",
},
{
name: "GetZone should get real zone if both node's zone and FD are set",
zone: "1",
faultDomain: "99",
expected: "eastus-1",
},
{
name: "GetZone should get faultDomain if node's zone isn't set",
faultDomain: "99",
expected: "99",
},
}

for _, test := range testcases {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}

mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, fmt.Sprintf(`{"compute":{"zone":"%s", "platformFaultDomain":"%s", "location":"eastus"}}`, test.zone, test.faultDomain))
}))
go func() {
http.Serve(listener, mux)
}()
defer listener.Close()

cloud.metadata, err = NewInstanceMetadataService("http://" + listener.Addr().String() + "/")
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}

zone, err := cloud.GetZone(context.Background())
if err != nil {
t.Errorf("Test [%s] unexpected error: %v", test.name, err)
}
if zone.FailureDomain != test.expected {
t.Errorf("Test [%s] unexpected zone: %s, expected %q", test.name, zone.FailureDomain, test.expected)
}
if zone.Region != cloud.Location {
t.Errorf("Test [%s] unexpected region: %s, expected: %s", test.name, zone.Region, cloud.Location)
}
}
}

0 comments on commit 667ea63

Please sign in to comment.