Skip to content

Commit

Permalink
Enable multi-tenant authentication with auxiliary token provider
Browse files Browse the repository at this point in the history
  • Loading branch information
zarvd committed May 20, 2024
1 parent 5ba3186 commit 43112fb
Show file tree
Hide file tree
Showing 13 changed files with 754 additions and 585 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ require (
k8s.io/klog/v2 v2.120.1
k8s.io/kubelet v0.30.1
k8s.io/utils v0.0.0-20231127182322-b307cd553661
sigs.k8s.io/cloud-provider-azure/pkg/azclient v0.0.19
sigs.k8s.io/cloud-provider-azure/pkg/azclient v0.0.21
sigs.k8s.io/cloud-provider-azure/pkg/azclient/configloader v0.0.11
sigs.k8s.io/yaml v1.4.0
)
Expand Down
161 changes: 2 additions & 159 deletions go.sum

Large diffs are not rendered by default.

70 changes: 70 additions & 0 deletions pkg/azureclients/armauth/multi_tenant_token_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package armauth

import (
"context"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/go-logr/logr"
)

// MultiTenantTokenProvider is the track1 multi-tenant token provider wrapper for track2 implementation.
type MultiTenantTokenProvider struct {
logger logr.Logger
primaryCredential azcore.TokenCredential
auxiliaryCredentials []azcore.TokenCredential
timeout time.Duration
scope string
}

func NewMultiTenantTokenProvider(
logger logr.Logger,
primaryCredential azcore.TokenCredential,
auxiliaryCredentials []azcore.TokenCredential,
scope string,
) (*MultiTenantTokenProvider, error) {
return &MultiTenantTokenProvider{
logger: logger,
primaryCredential: primaryCredential,
auxiliaryCredentials: auxiliaryCredentials,
timeout: 10 * time.Second,
scope: scope,
}, nil
}

func (p *MultiTenantTokenProvider) PrimaryOAuthToken() string {
p.logger.V(4).Info("Fetching primary oauth token")
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()

token, err := p.primaryCredential.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{p.scope},
})
if err != nil {
p.logger.Error(err, "Failed to fetch primary OAuth token")
return ""
}
return token.Token
}

func (p *MultiTenantTokenProvider) AuxiliaryOAuthTokens() []string {
p.logger.V(4).Info("Fetching auxiliary oauth token", "num-credentials", len(p.auxiliaryCredentials))
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()

var tokens []string
for _, cred := range p.auxiliaryCredentials {
token, err := cred.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{p.scope},
})
if err != nil {
p.logger.Error(err, "Failed to fetch auxiliary OAuth token")
return nil
}

tokens = append(tokens, token.Token)
}

return tokens
}
47 changes: 47 additions & 0 deletions pkg/azureclients/armauth/token_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package armauth

import (
"context"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/go-logr/logr"
)

// TokenProvider is the track1 token provider wrapper for track2 implementation.
type TokenProvider struct {
logger logr.Logger
credential azcore.TokenCredential
timeout time.Duration
scope string
}

func NewTokenProvider(
logger logr.Logger,
credential azcore.TokenCredential,
scope string,
) (*TokenProvider, error) {
return &TokenProvider{
logger: logger,
credential: credential,
timeout: 10 * time.Second,
scope: scope,
}, nil
}

func (p *TokenProvider) OAuthToken() string {
p.logger.V(4).Info("Fetching OAuth token")
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()

token, err := p.credential.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{p.scope},
})
if err != nil {
p.logger.Error(err, "Failed to fetch OAuth token")
return ""
}
p.logger.V(4).Info("Fetched OAuth token successfully", "token", token.Token)
return token.Token
}
32 changes: 17 additions & 15 deletions pkg/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,18 +693,19 @@ func (az *Cloud) InitializeCloudFromConfig(ctx context.Context, config *Config,
return nil
}

var authProvider *azclient.AuthProvider
authProvider, err = azclient.NewAuthProvider(&az.ARMClientConfig, &az.AzureAuthConfig.AzureAuthConfig)
if err != nil {
return err
}
// If uses network resources in different AAD Tenant, then prepare corresponding Service Principal Token for VM/VMSS client and network resources client
multiTenantServicePrincipalToken, networkResourceServicePrincipalToken, err := az.getAuthTokenInMultiTenantEnv(servicePrincipalToken)
multiTenantServicePrincipalToken, networkResourceServicePrincipalToken, err := az.getAuthTokenInMultiTenantEnv(servicePrincipalToken, authProvider)
if err != nil {
return err
}
az.configAzureClients(servicePrincipalToken, multiTenantServicePrincipalToken, networkResourceServicePrincipalToken)

if az.ComputeClientFactory == nil {
authProvider, err := azclient.NewAuthProvider(&az.ARMClientConfig, &az.AzureAuthConfig.AzureAuthConfig)
if err != nil {
return err
}
var cred azcore.TokenCredential
if authProvider.IsMultiTenantModeEnabled() {
multiTenantCred := authProvider.GetMultiTenantIdentity()
Expand Down Expand Up @@ -888,21 +889,21 @@ func (az *Cloud) setLBDefaults(config *Config) error {
return nil
}

func (az *Cloud) getAuthTokenInMultiTenantEnv(_ *adal.ServicePrincipalToken) (*adal.MultiTenantServicePrincipalToken, *adal.ServicePrincipalToken, error) {
func (az *Cloud) getAuthTokenInMultiTenantEnv(_ *adal.ServicePrincipalToken, authProvider *azclient.AuthProvider) (adal.MultitenantOAuthTokenProvider, adal.OAuthTokenProvider, error) {
var err error
var multiTenantServicePrincipalToken *adal.MultiTenantServicePrincipalToken
var networkResourceServicePrincipalToken *adal.ServicePrincipalToken
var multiTenantOAuthToken adal.MultitenantOAuthTokenProvider
var networkResourceServicePrincipalToken adal.OAuthTokenProvider
if az.Config.UsesNetworkResourceInDifferentTenant() {
multiTenantServicePrincipalToken, err = ratelimitconfig.GetMultiTenantServicePrincipalToken(&az.Config.AzureAuthConfig, &az.Environment)
multiTenantOAuthToken, err = ratelimitconfig.GetMultiTenantServicePrincipalToken(&az.Config.AzureAuthConfig, &az.Environment, authProvider)
if err != nil {
return nil, nil, err
}
networkResourceServicePrincipalToken, err = ratelimitconfig.GetNetworkResourceServicePrincipalToken(&az.Config.AzureAuthConfig, &az.Environment)
networkResourceServicePrincipalToken, err = ratelimitconfig.GetNetworkResourceServicePrincipalToken(&az.Config.AzureAuthConfig, &az.Environment, authProvider)
if err != nil {
return nil, nil, err
}
}
return multiTenantServicePrincipalToken, networkResourceServicePrincipalToken, nil
return multiTenantOAuthToken, networkResourceServicePrincipalToken, nil
}

func (az *Cloud) setCloudProviderBackoffDefaults(config *Config) wait.Backoff {
Expand Down Expand Up @@ -947,8 +948,8 @@ func (az *Cloud) setCloudProviderBackoffDefaults(config *Config) wait.Backoff {

func (az *Cloud) configAzureClients(
servicePrincipalToken *adal.ServicePrincipalToken,
multiTenantServicePrincipalToken *adal.MultiTenantServicePrincipalToken,
networkResourceServicePrincipalToken *adal.ServicePrincipalToken) {
multiTenantOAuthTokenProvider adal.MultitenantOAuthTokenProvider,
networkResourceServicePrincipalToken adal.OAuthTokenProvider) {
azClientConfig := az.getAzureClientConfig(servicePrincipalToken)

// Prepare AzureClientConfig for all azure clients
Expand Down Expand Up @@ -981,8 +982,9 @@ func (az *Cloud) configAzureClients(
zoneClientConfig := azClientConfig.WithRateLimiter(nil)

// If uses network resources in different AAD Tenant, update Authorizer for VM/VMSS/VMAS client config
if multiTenantServicePrincipalToken != nil {
multiTenantServicePrincipalTokenAuthorizer := autorest.NewMultiTenantServicePrincipalTokenAuthorizer(multiTenantServicePrincipalToken)
if multiTenantOAuthTokenProvider != nil {
multiTenantServicePrincipalTokenAuthorizer := autorest.NewMultiTenantServicePrincipalTokenAuthorizer(multiTenantOAuthTokenProvider)

vmClientConfig.Authorizer = multiTenantServicePrincipalTokenAuthorizer
vmssClientConfig.Authorizer = multiTenantServicePrincipalTokenAuthorizer
vmssVMClientConfig.Authorizer = multiTenantServicePrincipalTokenAuthorizer
Expand Down
8 changes: 8 additions & 0 deletions pkg/provider/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2089,6 +2089,7 @@ func TestNewCloudFromJSON(t *testing.T) {
// Test Backoff and Rate Limit defaults (json)
func TestCloudDefaultConfigFromJSON(t *testing.T) {
config := `{
"tenantId": "--tenant-id--",
"aadClientId": "--aad-client-id--",
"aadClientSecret": "--aad-client-secret--"
}`
Expand All @@ -2099,6 +2100,7 @@ func TestCloudDefaultConfigFromJSON(t *testing.T) {
// Test Backoff and Rate Limit defaults (yaml)
func TestCloudDefaultConfigFromYAML(t *testing.T) {
config := `
tenantId: --tenant-id--
aadClientId: --aad-client-id--
aadClientSecret: --aad-client-secret--
`
Expand Down Expand Up @@ -2294,9 +2296,15 @@ func getCloudFromConfig(t *testing.T, config string) *Cloud {
mockZoneClient := az.ZoneClient.(*mockzoneclient.MockInterface)
mockZoneClient.EXPECT().GetZones(gomock.Any(), gomock.Any()).Return(map[string][]string{"eastus": {"1", "2", "3"}}, nil)

// Skip AAD client cert path validation since it will read the file from the path
aadCertPath := c.AADClientCertPath
c.AADClientCertPath = ""

err = az.InitializeCloudFromConfig(context.Background(), c, false, true)
assert.NoError(t, err)

az.AADClientCertPath = aadCertPath

return az
}

Expand Down
Loading

0 comments on commit 43112fb

Please sign in to comment.