diff --git a/sdk/azidentity/client_assertion_credential_test.go b/sdk/azidentity/client_assertion_credential_test.go index 497f48eb2c7f..12d5a8d2c162 100644 --- a/sdk/azidentity/client_assertion_credential_test.go +++ b/sdk/azidentity/client_assertion_credential_test.go @@ -65,6 +65,13 @@ func TestClientAssertionCredentialCallbackError(t *testing.T) { } } +func TestClientAssertionCredentialNilCallback(t *testing.T) { + _, err := NewClientAssertionCredential(fakeTenantID, fakeClientID, nil, nil) + if err == nil { + t.Fatal("expected an error") + } +} + func TestClientAssertionCredential_Live(t *testing.T) { data, err := os.ReadFile(liveSP.pemPath) if err != nil { diff --git a/sdk/azidentity/environment_credential_test.go b/sdk/azidentity/environment_credential_test.go index 3d4b7d2d1e5b..4282615c4791 100644 --- a/sdk/azidentity/environment_credential_test.go +++ b/sdk/azidentity/environment_credential_test.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "os" + "path/filepath" "reflect" "strings" "testing" @@ -93,6 +94,30 @@ func TestEnvironmentCredential_ClientSecretSet(t *testing.T) { } } +func TestEnvironmentCredential_CertificateErrors(t *testing.T) { + resetEnvironmentVarsForTest() + for _, test := range []struct { + name, path string + }{ + {"file doesn't exist", filepath.Join(t.TempDir(), t.Name())}, + {"invalid file", "testdata/certificate-wrong-key.pem"}, + } { + t.Run(test.name, func(t *testing.T) { + for k, v := range map[string]string{ + azureClientID: fakeClientID, + azureClientCertificatePath: test.path, + azureTenantID: fakeTenantID, + } { + t.Setenv(k, v) + _, err := NewEnvironmentCredential(nil) + if err == nil { + t.Fatal("expected an error") + } + } + }) + } +} + func TestEnvironmentCredential_ClientCertificatePathSet(t *testing.T) { resetEnvironmentVarsForTest() err := os.Setenv(azureTenantID, fakeTenantID) diff --git a/sdk/azidentity/managed_identity_client_test.go b/sdk/azidentity/managed_identity_client_test.go index aa604b266269..f85b3f1f4d70 100644 --- a/sdk/azidentity/managed_identity_client_test.go +++ b/sdk/azidentity/managed_identity_client_test.go @@ -15,6 +15,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) type userAgentValidatingPolicy struct { @@ -73,3 +74,63 @@ func TestManagedIdentityClient_ApplicationID(t *testing.T) { t.Fatal(err) } } + +func TestManagedIdentityClient_UserAssignedIDWarning(t *testing.T) { + for _, test := range []struct { + name string + createRequest func(*managedIdentityClient) error + }{ + { + name: "Azure Arc", + createRequest: func(client *managedIdentityClient) error { + _, err := client.createAzureArcAuthRequest(context.Background(), client.id, []string{liveTestScope}, "key") + return err + }, + }, + { + name: "Cloud Shell", + createRequest: func(client *managedIdentityClient) error { + _, err := client.createCloudShellAuthRequest(context.Background(), client.id, []string{liveTestScope}) + return err + }, + }, + { + name: "Service Fabric", + createRequest: func(client *managedIdentityClient) error { + _, err := client.createServiceFabricAuthRequest(context.Background(), client.id, []string{liveTestScope}) + return err + }, + }, + } { + for _, id := range []ManagedIDKind{ClientID(fakeClientID), ResourceID(fakeResourceID)} { + s := "-ClientID" + if id.String() == fakeResourceID { + s = "-ResourceID" + } + t.Run(test.name+s, func(t *testing.T) { + msgs := []string{} + log.SetListener(func(event log.Event, msg string) { + if event == EventAuthentication { + msgs = append(msgs, msg) + } + }) + client, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{ + ID: id, + }) + if err != nil { + t.Fatal(err) + } + err = test.createRequest(client) + if err != nil { + t.Fatal(err) + } + for _, msg := range msgs { + if strings.Contains(msg, test.name) && strings.Contains(msg, "user-assigned") { + return + } + } + t.Fatalf("expected warning about user-assigned ID, got:\n%s", strings.Join(msgs, "\n")) + }) + } + } +} diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index 64c6013611a3..e939a6946765 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -85,6 +85,77 @@ func TestManagedIdentityCredential_AzureArc(t *testing.T) { testGetTokenSuccess(t, cred) } +func TestManagedIdentityCredential_AzureArcErrors(t *testing.T) { + for k, v := range map[string]string{ + arcIMDSEndpoint: "https://localhost", + identityEndpoint: "https://localhost", + } { + t.Setenv(k, v) + } + + for _, test := range []struct { + challenge, name string + statusCode int + }{ + {name: "no challenge", statusCode: http.StatusUnauthorized}, + {name: "malformed challenge", challenge: "Basic realm", statusCode: http.StatusUnauthorized}, + {name: "unexpected status code", statusCode: http.StatusOK}, + } { + t.Run(test.name, func(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", test.challenge), + mock.WithStatusCode(test.statusCode), + ) + cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ + ClientOptions: azcore.ClientOptions{Transport: srv}, + }) + if err != nil { + t.Fatal(err) + } + _, err = cred.GetToken(context.Background(), testTRO) + if err == nil { + t.Fatal("expected an error") + } + }) + } + t.Run("failed to get key", func(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.SetError(fmt.Errorf("it didn't work")) + cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ + ClientOptions: azcore.ClientOptions{ + Retry: policy.RetryOptions{MaxRetries: -1}, + Transport: srv, + }, + }) + if err != nil { + t.Fatal(err) + } + _, err = cred.GetToken(context.Background(), testTRO) + if err == nil { + t.Fatal("expected an error") + } + }) + t.Run("no key file", func(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", "Basic realm="+filepath.Join(t.TempDir(), t.Name())), + mock.WithStatusCode(http.StatusUnauthorized), + ) + cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}}) + if err != nil { + t.Fatal(err) + } + _, err = cred.GetToken(context.Background(), testTRO) + if err == nil { + t.Fatal("expected an error") + } + }) +} + func TestManagedIdentityCredential_CloudShell(t *testing.T) { validateReq := func(req *http.Request) *http.Response { err := req.ParseForm() diff --git a/sdk/azidentity/workload_identity_test.go b/sdk/azidentity/workload_identity_test.go index 20ba3fde5b62..2557a43893d9 100644 --- a/sdk/azidentity/workload_identity_test.go +++ b/sdk/azidentity/workload_identity_test.go @@ -119,6 +119,10 @@ func TestWorkloadIdentityCredential(t *testing.T) { t.Fatal(err) } testGetTokenSuccess(t, cred) + _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"scope"}}) + if err != nil { + t.Fatal(err) + } } func TestWorkloadIdentityCredential_Expiration(t *testing.T) { @@ -186,6 +190,25 @@ func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) { } } +func TestWorkloadIdentityCredential_NoFile(t *testing.T) { + for k, v := range map[string]string{ + azureClientID: fakeClientID, + azureFederatedTokenFile: filepath.Join(t.TempDir(), t.Name()), + azureTenantID: fakeTenantID, + } { + t.Setenv(k, v) + } + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientOptions: policy.ClientOptions{Transport: &mockSTS{}}, + }) + if err != nil { + t.Fatal(err) + } + if _, err = cred.GetToken(context.Background(), testTRO); err == nil { + t.Fatal("expected an error") + } +} + func TestWorkloadIdentityCredential_Options(t *testing.T) { clientID := "not-" + fakeClientID tenantID := "not-" + fakeTenantID