Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DefaultAzureCredential struct #15759

Merged
merged 3 commits into from
Oct 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
* `AuthenticationFailedError.RawResponse()` returns the HTTP response motivating the error,
if available

### Other Changes
* `NewDefaultAzureCredential()` returns `*DefaultAzureCredential` instead of `*ChainedTokenCredential`

## 0.11.0 (2021-09-08)
### Breaking Changes
Expand Down
35 changes: 20 additions & 15 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,30 @@ const (
customHostString = "https://custommock.com/"
)

// Set AZURE_AUTHORITY_HOST for the duration of a test. Restore its prior value
// after the test completes. Prevents tests which set the variable from breaking live
// tests in sovereign clouds. Obviated by 1.17's T.Setenv
func setEnvAuthorityHost(host string, t *testing.T) {
originalHost := os.Getenv("AZURE_AUTHORITY_HOST")
err := os.Setenv("AZURE_AUTHORITY_HOST", host)
if err != nil {
t.Fatalf("Unexpected error setting AZURE_AUTHORITY_HOST: %v", err)
// Set environment variables for the duration of a test. Restore their prior values
// after the test completes. Obviated by 1.17's T.Setenv
func setEnvironmentVariables(t *testing.T, vars map[string]string) {
priorValues := make(map[string]string, len(vars))
for k, v := range vars {
priorValues[k] = os.Getenv(k)
err := os.Setenv(k, v)
if err != nil {
t.Fatalf("Unexpected error setting %s: %v", k, err)
}
}

t.Cleanup(func() {
err = os.Setenv("AZURE_AUTHORITY_HOST", originalHost)
if err != nil {
t.Fatalf("Unexpected error resetting AZURE_AUTHORITY_HOST: %v", err)
for k, v := range priorValues {
err := os.Setenv(k, v)
if err != nil {
t.Fatalf("Unexpected error resetting %s: %v", k, err)
}
}
})
}

func Test_SetEnvAuthorityHost(t *testing.T) {
setEnvAuthorityHost(envHostString, t)
setEnvironmentVariables(t, map[string]string{"AZURE_AUTHORITY_HOST": envHostString})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems that AZURE_AUTHORITY_HOST should be a constant. Pre-existing so you don't have to fix it in this PR, but good to clean up (probably worth checking for other cases).

authorityHost, err := setAuthorityHost("")
if err != nil {
t.Fatal(err)
Expand All @@ -64,7 +69,7 @@ func Test_SetEnvAuthorityHost(t *testing.T) {
}

func Test_CustomAuthorityHost(t *testing.T) {
setEnvAuthorityHost(envHostString, t)
setEnvironmentVariables(t, map[string]string{"AZURE_AUTHORITY_HOST": envHostString})
authorityHost, err := setAuthorityHost(customHostString)
if err != nil {
t.Fatal(err)
Expand All @@ -76,7 +81,7 @@ func Test_CustomAuthorityHost(t *testing.T) {
}

func Test_DefaultAuthorityHost(t *testing.T) {
setEnvAuthorityHost("", t)
setEnvironmentVariables(t, map[string]string{"AZURE_AUTHORITY_HOST": ""})
authorityHost, err := setAuthorityHost("")
if err != nil {
t.Fatal(err)
Expand All @@ -87,7 +92,7 @@ func Test_DefaultAuthorityHost(t *testing.T) {
}

func Test_NonHTTPSAuthorityHost(t *testing.T) {
setEnvAuthorityHost("", t)
setEnvironmentVariables(t, map[string]string{"AZURE_AUTHORITY_HOST": ""})
authorityHost, err := setAuthorityHost("http://foo.com")
if err == nil {
t.Fatal("Expected an error but did not receive one.")
Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/client_secret_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
const (
tenantID = "expected-tenant"
badTenantID = "bad_tenant"
clientID = "expected_client"
clientID = "expected-client-id"
secret = "secret"
wrongSecret = "wrong_secret"
tokenValue = "new_token"
Expand Down
30 changes: 23 additions & 7 deletions sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
package azidentity

import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

const (
Expand All @@ -23,13 +25,19 @@ type DefaultAzureCredentialOptions struct {
AuthorityHost AuthorityHost
}

// NewDefaultAzureCredential provides a default ChainedTokenCredential configuration for applications that will be deployed to Azure. The following credential
// types will be tried, in the following order:
// DefaultAzureCredential is a default credential chain for applications that will be deployed to Azure.
// It combines credentials suitable for deployed applications with credentials suitable in local development.
// It attempts to authenticate with each of these credential types, in the following order:
// - EnvironmentCredential
// - ManagedIdentityCredential
// - AzureCLICredential
// Consult the documentation for these credential types for more information on how they attempt authentication.
func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*ChainedTokenCredential, error) {
// Consult the documentation for these credential types for more information on how they authenticate.
type DefaultAzureCredential struct {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing doc comment

chain *ChainedTokenCredential
}

// NewDefaultAzureCredential creates a default credential chain for applications that will be deployed to Azure.
func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*DefaultAzureCredential, error) {
var creds []azcore.TokenCredential
errMsg := ""

Expand Down Expand Up @@ -67,6 +75,14 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Chained
logCredentialError(err.credentialType, err)
return nil, err
}
log.Write(EventCredential, "Azure Identity => NewDefaultAzureCredential() invoking NewChainedTokenCredential()")
return NewChainedTokenCredential(creds, nil)
chain, err := NewChainedTokenCredential(creds, nil)
if err != nil {
return nil, err
}
return &DefaultAzureCredential{chain: chain}, nil
}

// GetToken attempts to acquire a token from each of the default chain's credentials, stopping when one provides a token.
func (c *DefaultAzureCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) {
return c.chain.GetToken(ctx, opts)
}
40 changes: 14 additions & 26 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,26 @@
package azidentity

import (
"errors"
"context"
"testing"
)

const (
lengthOfChainOneExcluded = 2
lengthOfChainFull = 3
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

func TestDefaultAzureCredential_NilOptions(t *testing.T) {
resetEnvironmentVarsForTest()
err := initEnvironmentVarsForTest()
func TestDefaultAzureCredential_GetTokenSuccess(t *testing.T) {
env := map[string]string{"AZURE_TENANT_ID": tenantID, "AZURE_CLIENT_ID": clientID, "AZURE_CLIENT_SECRET": secret}
setEnvironmentVariables(t, env)
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))

cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{AuthorityHost: AuthorityHost(srv.URL()), ClientOptions: policy.ClientOptions{Transport: srv}})
if err != nil {
t.Fatalf("Unexpected error when initializing environment variables: %v", err)
t.Fatalf("Unable to create credential. Received: %v", err)
}
cred, err := NewDefaultAzureCredential(nil)
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"scope"}})
if err != nil {
t.Fatalf("Did not expect to receive an error in creating the credential")
}
c := newManagedIdentityClient(&ManagedIdentityCredentialOptions{})
// if the test is running in a MSI environment then the length of sources would be two since it will include environment credential and managed identity credential
if msiType, err := c.getMSIType(); !(msiType == msiTypeUnavailable || msiType == msiTypeUnknown) {
if len(cred.sources) != lengthOfChainFull {
t.Fatalf("Length of ChainedTokenCredential sources for DefaultAzureCredential. Expected: %d, Received: %d", lengthOfChainFull, len(cred.sources))
}
//if a credential unavailable error is received or msiType is unknown then only the environment credential will be added
} else if unavailableErr := (*CredentialUnavailableError)(nil); errors.As(err, &unavailableErr) || msiType == msiTypeUnknown {
if len(cred.sources) != lengthOfChainOneExcluded {
t.Fatalf("Length of ChainedTokenCredential sources for DefaultAzureCredential. Expected: %d, Received: %d", lengthOfChainOneExcluded, len(cred.sources))
}
// if there is some other unexpected error then we fail here
} else if err != nil {
t.Fatalf("Received an error when trying to determine MSI type: %v", err)
t.Fatalf("GetToken error: %v", err)
}
}