diff --git a/pkg/azurefile/azure.go b/pkg/azurefile/azure.go index 1dc032bf6b..10e256a8c6 100644 --- a/pkg/azurefile/azure.go +++ b/pkg/azurefile/azure.go @@ -120,7 +120,18 @@ func getCloudProvider(ctx context.Context, kubeconfig, nodeID, secretName, secre } } else { config.UserAgent = userAgent - if err = az.InitializeCloudFromConfig(context.TODO(), config, fromSecret, false); err != nil { + // these environment variables are injected by workload identity webhook + if tenantID := os.Getenv("AZURE_TENANT_ID"); tenantID != "" { + config.TenantID = tenantID + } + if clientID := os.Getenv("AZURE_CLIENT_ID"); clientID != "" { + config.AADClientID = clientID + } + if federatedTokenFile := os.Getenv("AZURE_FEDERATED_TOKEN_FILE"); federatedTokenFile != "" { + config.AADFederatedTokenFile = federatedTokenFile + config.UseFederatedWorkloadIdentityExtension = true + } + if err = az.InitializeCloudFromConfig(ctx, config, fromSecret, false); err != nil { klog.Warningf("InitializeCloudFromConfig failed with error: %v", err) } } diff --git a/pkg/azurefile/azure_test.go b/pkg/azurefile/azure_test.go index fa7abf1b34..0446f2eb0f 100644 --- a/pkg/azurefile/azure_test.go +++ b/pkg/azurefile/azure_test.go @@ -86,13 +86,18 @@ users: }() tests := []struct { - desc string - createFakeCredFile bool - createFakeKubeConfig bool - kubeconfig string - userAgent string - allowEmptyCloudConfig bool - expectedErr testutil.TestError + desc string + createFakeCredFile bool + createFakeKubeConfig bool + setFederatedWorkloadIdentityEnv bool + kubeconfig string + userAgent string + allowEmptyCloudConfig bool + aadFederatedTokenFile string + useFederatedWorkloadIdentityExtension bool + aadClientID string + tenantID string + expectedErr testutil.TestError }{ { desc: "out of cluster, no kubeconfig, no credential file", @@ -137,6 +142,19 @@ users: allowEmptyCloudConfig: true, expectedErr: testutil.TestError{}, }, + { + desc: "[success] get azure client with workload identity", + createFakeKubeConfig: true, + createFakeCredFile: true, + setFederatedWorkloadIdentityEnv: true, + kubeconfig: fakeKubeConfig, + userAgent: "useragent", + useFederatedWorkloadIdentityExtension: true, + aadFederatedTokenFile: "fake-token-file", + aadClientID: "fake-client-id", + tenantID: "fake-tenant-id", + expectedErr: testutil.TestError{}, + }, } for _, test := range tests { @@ -145,7 +163,7 @@ users: t.Error(err) } defer func() { - if err := os.Remove(fakeKubeConfig); err != nil { + if err := os.Remove(fakeKubeConfig); err != nil && !os.IsNotExist(err) { t.Error(err) } }() @@ -159,7 +177,7 @@ users: t.Error(err) } defer func() { - if err := os.Remove(fakeCredFile); err != nil { + if err := os.Remove(fakeCredFile); err != nil && !os.IsNotExist(err) { t.Error(err) } }() @@ -172,6 +190,12 @@ users: } os.Setenv(DefaultAzureCredentialFileEnv, fakeCredFile) } + if test.setFederatedWorkloadIdentityEnv { + t.Setenv("AZURE_TENANT_ID", test.tenantID) + t.Setenv("AZURE_CLIENT_ID", test.aadClientID) + t.Setenv("AZURE_FEDERATED_TOKEN_FILE", test.aadFederatedTokenFile) + } + cloud, err := getCloudProvider(context.Background(), test.kubeconfig, "", "", "", test.userAgent, test.allowEmptyCloudConfig, false, 5, 10) if !testutil.AssertError(err, &test.expectedErr) && !strings.Contains(err.Error(), test.expectedErr.DefaultError.Error()) { t.Errorf("desc: %s,\n input: %q, getCloudProvider err: %v, expectedErr: %v", test.desc, test.kubeconfig, err, test.expectedErr) @@ -180,6 +204,10 @@ users: t.Errorf("return value of getCloudProvider should not be nil even there is error") } else { assert.Equal(t, cloud.UserAgent, test.userAgent) + assert.Equal(t, cloud.AADFederatedTokenFile, test.aadFederatedTokenFile) + assert.Equal(t, cloud.UseFederatedWorkloadIdentityExtension, test.useFederatedWorkloadIdentityExtension) + assert.Equal(t, cloud.AADClientID, test.aadClientID) + assert.Equal(t, cloud.TenantID, test.tenantID) } } }