Skip to content

Commit

Permalink
Increase azidentity test coverage (#21345)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Aug 10, 2023
1 parent 07c7a00 commit 514985a
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 0 deletions.
7 changes: 7 additions & 0 deletions sdk/azidentity/client_assertion_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ func TestClientAssertionCredentialCallbackError(t *testing.T) {
}
}

func TestClientAssertionCredentialNilCallback(t *testing.T) {
_, err := NewClientAssertionCredential(fakeTenantID, fakeClientID, nil, nil)
if err == nil {
t.Fatal("expected an error")
}
}

func TestClientAssertionCredential_Live(t *testing.T) {
data, err := os.ReadFile(liveSP.pemPath)
if err != nil {
Expand Down
25 changes: 25 additions & 0 deletions sdk/azidentity/environment_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"context"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -93,6 +94,30 @@ func TestEnvironmentCredential_ClientSecretSet(t *testing.T) {
}
}

func TestEnvironmentCredential_CertificateErrors(t *testing.T) {
resetEnvironmentVarsForTest()
for _, test := range []struct {
name, path string
}{
{"file doesn't exist", filepath.Join(t.TempDir(), t.Name())},
{"invalid file", "testdata/certificate-wrong-key.pem"},
} {
t.Run(test.name, func(t *testing.T) {
for k, v := range map[string]string{
azureClientID: fakeClientID,
azureClientCertificatePath: test.path,
azureTenantID: fakeTenantID,
} {
t.Setenv(k, v)
_, err := NewEnvironmentCredential(nil)
if err == nil {
t.Fatal("expected an error")
}
}
})
}
}

func TestEnvironmentCredential_ClientCertificatePathSet(t *testing.T) {
resetEnvironmentVarsForTest()
err := os.Setenv(azureTenantID, fakeTenantID)
Expand Down
61 changes: 61 additions & 0 deletions sdk/azidentity/managed_identity_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
)

type userAgentValidatingPolicy struct {
Expand Down Expand Up @@ -73,3 +74,63 @@ func TestManagedIdentityClient_ApplicationID(t *testing.T) {
t.Fatal(err)
}
}

func TestManagedIdentityClient_UserAssignedIDWarning(t *testing.T) {
for _, test := range []struct {
name string
createRequest func(*managedIdentityClient) error
}{
{
name: "Azure Arc",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createAzureArcAuthRequest(context.Background(), client.id, []string{liveTestScope}, "key")
return err
},
},
{
name: "Cloud Shell",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createCloudShellAuthRequest(context.Background(), client.id, []string{liveTestScope})
return err
},
},
{
name: "Service Fabric",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createServiceFabricAuthRequest(context.Background(), client.id, []string{liveTestScope})
return err
},
},
} {
for _, id := range []ManagedIDKind{ClientID(fakeClientID), ResourceID(fakeResourceID)} {
s := "-ClientID"
if id.String() == fakeResourceID {
s = "-ResourceID"
}
t.Run(test.name+s, func(t *testing.T) {
msgs := []string{}
log.SetListener(func(event log.Event, msg string) {
if event == EventAuthentication {
msgs = append(msgs, msg)
}
})
client, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{
ID: id,
})
if err != nil {
t.Fatal(err)
}
err = test.createRequest(client)
if err != nil {
t.Fatal(err)
}
for _, msg := range msgs {
if strings.Contains(msg, test.name) && strings.Contains(msg, "user-assigned") {
return
}
}
t.Fatalf("expected warning about user-assigned ID, got:\n%s", strings.Join(msgs, "\n"))
})
}
}
}
71 changes: 71 additions & 0 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,77 @@ func TestManagedIdentityCredential_AzureArc(t *testing.T) {
testGetTokenSuccess(t, cred)
}

func TestManagedIdentityCredential_AzureArcErrors(t *testing.T) {
for k, v := range map[string]string{
arcIMDSEndpoint: "https://localhost",
identityEndpoint: "https://localhost",
} {
t.Setenv(k, v)
}

for _, test := range []struct {
challenge, name string
statusCode int
}{
{name: "no challenge", statusCode: http.StatusUnauthorized},
{name: "malformed challenge", challenge: "Basic realm", statusCode: http.StatusUnauthorized},
{name: "unexpected status code", statusCode: http.StatusOK},
} {
t.Run(test.name, func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", test.challenge),
mock.WithStatusCode(test.statusCode),
)
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{Transport: srv},
})
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err == nil {
t.Fatal("expected an error")
}
})
}
t.Run("failed to get key", func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.SetError(fmt.Errorf("it didn't work"))
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{
Retry: policy.RetryOptions{MaxRetries: -1},
Transport: srv,
},
})
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err == nil {
t.Fatal("expected an error")
}
})
t.Run("no key file", func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", "Basic realm="+filepath.Join(t.TempDir(), t.Name())),
mock.WithStatusCode(http.StatusUnauthorized),
)
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}})
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err == nil {
t.Fatal("expected an error")
}
})
}

func TestManagedIdentityCredential_CloudShell(t *testing.T) {
validateReq := func(req *http.Request) *http.Response {
err := req.ParseForm()
Expand Down
23 changes: 23 additions & 0 deletions sdk/azidentity/workload_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ func TestWorkloadIdentityCredential(t *testing.T) {
t.Fatal(err)
}
testGetTokenSuccess(t, cred)
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"scope"}})
if err != nil {
t.Fatal(err)
}
}

func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
Expand Down Expand Up @@ -186,6 +190,25 @@ func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) {
}
}

func TestWorkloadIdentityCredential_NoFile(t *testing.T) {
for k, v := range map[string]string{
azureClientID: fakeClientID,
azureFederatedTokenFile: filepath.Join(t.TempDir(), t.Name()),
azureTenantID: fakeTenantID,
} {
t.Setenv(k, v)
}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: &mockSTS{}},
})
if err != nil {
t.Fatal(err)
}
if _, err = cred.GetToken(context.Background(), testTRO); err == nil {
t.Fatal("expected an error")
}
}

func TestWorkloadIdentityCredential_Options(t *testing.T) {
clientID := "not-" + fakeClientID
tenantID := "not-" + fakeTenantID
Expand Down

0 comments on commit 514985a

Please sign in to comment.