From 64f0874f79c538bb29e6e1c6ebf427f7bf45f27b Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Thu, 9 Jan 2025 12:16:33 -0800 Subject: [PATCH] Fix Azure join method throttling (#50251) --- lib/auth/bot_test.go | 2 +- lib/auth/join_azure.go | 139 ++++++++++--- lib/auth/join_azure_test.go | 379 +++++++++++++++++++++++++++++++----- 3 files changed, 442 insertions(+), 78 deletions(-) diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index 5ff53115b5374..8809cff777d6b 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -568,7 +568,7 @@ func TestRegisterBot_RemoteAddr(t *testing.T) { rsID := vmResourceID(subID, resourceGroup, "test-vm") vmID := "vmID" - accessToken, err := makeToken(rsID, a.clock.Now()) + accessToken, err := makeToken(rsID, "", a.clock.Now()) require.NoError(t, err) // add token to auth server diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 721a53ff2d7fa..8d120044a155f 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -19,10 +19,12 @@ package auth import ( + "cmp" "context" "crypto/x509" "encoding/base64" "encoding/pem" + "log/slog" "net/url" "slices" "strings" @@ -30,6 +32,8 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/coreos/go-oidc" "github.com/digitorus/pkcs7" "github.com/go-jose/go-jose/v3/jwt" @@ -38,12 +42,20 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/utils" ) -const azureAccessTokenAudience = "https://management.azure.com/" +const ( + azureAccessTokenAudience = 
"https://management.azure.com/" + + // azureUserAgent specifies the Azure User-Agent identification for telemetry. + azureUserAgent = "teleport" + // azureVirtualMachine specifies the Azure virtual machine resource type. + azureVirtualMachine = "virtualMachines" +) // Structs for unmarshaling attested data. Schema can be found at // https://learn.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service?tabs=linux#response-2 @@ -76,9 +88,23 @@ type attestedData struct { type accessTokenClaims struct { jwt.Claims - ResourceID string `json:"xms_mirid"` - TenantID string `json:"tid"` - Version string `json:"ver"` + TenantID string `json:"tid"` + Version string `json:"ver"` + + // Azure JWT tokens include two optional claims that can be used to validate + // the subscription and resource group of a joining node. These claims hold + // different values depending on the assigned Managed Identity of the Azure VM: + // - xms_mirid: + // - For System-Assigned Identity it represents the resource id of the VM. + // - For User-Assigned Identity it represents the resource id of the user-assigned identity. + // - xms_az_rid: + // - For System-Assigned Identity this claim is omitted. + // - For User-Assigned Identity it represents the resource id of the VM. + // + // More details at: https://learn.microsoft.com/en-us/answers/questions/1282788/existence-of-xms-az-rid-field-in-activity-logs-of + + ManangedIdentityResourceID string `json:"xms_mirid"` + AzureResourceID string `json:"xms_az_rid"` } type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error) @@ -144,7 +170,16 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error { } if cfg.getVMClient == nil { cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) { - client, err := azure.NewVirtualMachinesClient(subscriptionID, token, nil) + // The User-Agent is added for debugging purposes. 
It helps identify + // and isolate teleport traffic. + opts := &armpolicy.ClientOptions{ + ClientOptions: policy.ClientOptions{ + Telemetry: policy.TelemetryOptions{ + ApplicationID: azureUserAgent, + }, + }, + } + client, err := azure.NewVirtualMachinesClient(subscriptionID, token, opts) return client, trace.Wrap(err) } } @@ -210,8 +245,16 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s } // verifyVMIdentity verifies that the provided access token came from the -// correct Azure VM. -func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken, subscriptionID, vmID string, requestStart time.Time) (*azure.VirtualMachine, error) { +// correct Azure VM. Returns the Azure join attributes. +func verifyVMIdentity( + ctx context.Context, + cfg *azureRegisterConfig, + accessToken, + subscriptionID, + vmID string, + requestStart time.Time, + logger *slog.Logger, +) (joinAttrs *workloadidentityv1pb.JoinAttrsAzure, err error) { tokenClaims, err := cfg.verify(ctx, accessToken) if err != nil { return nil, trace.Wrap(err) @@ -239,6 +282,20 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken return nil, trace.Wrap(err) } + // Listing all VMs in an Azure subscription during the verification process + // is problematic when there are a large number of VMs in an Azure subscription. + // In some cases this can lead to throttling due to Azure API rate limits. + // To address the issue, the verification process will first attempt to + // parse required VM identifiers from the token claims. If this method fails, + // fall back to the original method of listing VMs and parsing the VM identifiers + // from the VM resource. + vmSubscription, vmResourceGroup, err := claimsToIdentifiers(tokenClaims) + if err == nil { + return azureJoinToAttrs(vmSubscription, vmResourceGroup), nil + } + logger.WarnContext(ctx, "Failed to parse VM identifiers from claims. 
Retrying with Azure VM API.", + "error", err) + tokenCredential := azure.NewStaticCredential(azcore.AccessToken{ Token: accessToken, ExpiresOn: tokenClaims.Expiry.Time(), @@ -248,7 +305,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken return nil, trace.Wrap(err) } - resourceID, err := arm.ParseResourceID(tokenClaims.ResourceID) + resourceID, err := arm.ParseResourceID(tokenClaims.ManangedIdentityResourceID) if err != nil { return nil, trace.Wrap(err) } @@ -257,8 +314,8 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken // If the token is from the system-assigned managed identity, the resource ID // is for the VM itself and we can use it to look up the VM. - if slices.Contains(resourceID.ResourceType.Types, "virtualMachines") { - vm, err = vmClient.Get(ctx, tokenClaims.ResourceID) + if slices.Contains(resourceID.ResourceType.Types, azureVirtualMachine) { + vm, err = vmClient.Get(ctx, tokenClaims.ManangedIdentityResourceID) if err != nil { return nil, trace.Wrap(err) } @@ -277,21 +334,35 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken return nil, trace.Wrap(err) } } + return azureJoinToAttrs(vm.Subscription, vm.ResourceGroup), nil +} - return vm, nil +// claimsToIdentifiers returns the vm identifiers from the provided claims. +func claimsToIdentifiers(tokenClaims *accessTokenClaims) (subscriptionID, resourceGroupID string, err error) { + // xms_az_rid claim is omitted when the VM is assigned a System-Assigned Identity. + // The xms_mirid claim should be used instead. 
+ rid := cmp.Or(tokenClaims.AzureResourceID, tokenClaims.ManangedIdentityResourceID) + resourceID, err := arm.ParseResourceID(rid) + if err != nil { + return "", "", trace.Wrap(err, "failed to parse resource id from claims") + } + if !slices.Contains(resourceID.ResourceType.Types, azureVirtualMachine) { + return "", "", trace.BadParameter("unexpected resource type: %q", resourceID.ResourceType.Type) + } + return resourceID.SubscriptionID, resourceID.ResourceGroupName, nil } -func checkAzureAllowRules(vm *azure.VirtualMachine, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error { - for _, rule := range allowRules { - if rule.Subscription != vm.Subscription { +func checkAzureAllowRules(vmID string, attrs *workloadidentityv1pb.JoinAttrsAzure, token *types.ProvisionTokenV2) error { + for _, rule := range token.Spec.Azure.Allow { + if rule.Subscription != attrs.Subscription { continue } - if !azureResourceGroupIsAllowed(rule.ResourceGroups, vm.ResourceGroup) { + if !azureResourceGroupIsAllowed(rule.ResourceGroups, attrs.ResourceGroup) { continue } return nil } - return trace.AccessDenied("instance %v did not match any allow rules in token %v", vm.Name, token) + return trace.AccessDenied("instance %v did not match any allow rules in token %v", vmID, token.GetName()) } func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup string) bool { if len(allowedResourceGroups) == 0 { @@ -312,37 +383,48 @@ func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup return false } -func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *proto.RegisterUsingAzureMethodRequest, cfg *azureRegisterConfig) error { +func azureJoinToAttrs(subscriptionID, resourceGroupID string) *workloadidentityv1pb.JoinAttrsAzure { + return &workloadidentityv1pb.JoinAttrsAzure{ + Subscription: subscriptionID, + ResourceGroup: resourceGroupID, + } +} + +func (a *Server) checkAzureRequest( + ctx context.Context, + 
challenge string, + req *proto.RegisterUsingAzureMethodRequest, + cfg *azureRegisterConfig, +) (*workloadidentityv1pb.JoinAttrsAzure, error) { requestStart := a.clock.Now() tokenName := req.RegisterUsingTokenRequest.Token provisionToken, err := a.GetToken(ctx, tokenName) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } if provisionToken.GetJoinMethod() != types.JoinMethodAzure { - return trace.AccessDenied("this token does not support the Azure join method") + return nil, trace.AccessDenied("this token does not support the Azure join method") } subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - vm, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart) + attrs, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart, a.logger) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } token, ok := provisionToken.(*types.ProvisionTokenV2) if !ok { - return trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken) + return nil, trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken) } - - if err := checkAzureAllowRules(vm, token.GetName(), token.Spec.Azure.Allow); err != nil { - return trace.Wrap(err) + if err := checkAzureAllowRules(vmID, attrs, token); err != nil { + return attrs, trace.Wrap(err) } - return nil + return attrs, nil } func generateAzureChallenge() (string, error) { @@ -399,7 +481,8 @@ func (a *Server) RegisterUsingAzureMethodWithOpts( return nil, trace.Wrap(err) } - if err := a.checkAzureRequest(ctx, challenge, req, cfg); err != nil { + _, err = a.checkAzureRequest(ctx, challenge, req, cfg) + if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/join_azure_test.go b/lib/auth/join_azure_test.go index 0944e1ac9ed48..c7cc7c5b18954 
100644 --- a/lib/auth/join_azure_test.go +++ b/lib/auth/join_azure_test.go @@ -104,12 +104,16 @@ func withChallengeAzure(challenge string) azureChallengeResponseOption { } func vmResourceID(subscription, resourceGroup, name string) string { - return resourceID("virtualMachines", subscription, resourceGroup, name) + return resourceID("Microsoft.Compute/virtualMachines", subscription, resourceGroup, name) +} + +func identityResourceID(subscription, resourceGroup, name string) string { + return resourceID("Microsoft.ManagedIdentity/userAssignedIdentities", subscription, resourceGroup, name) } func resourceID(resourceType, subscription, resourceGroup, name string) string { return fmt.Sprintf( - "/subscriptions/%v/resourcegroups/%v/providers/Microsoft.Compute/%v/%v", + "/subscriptions/%v/resourcegroups/%v/providers/%v/%v", subscription, resourceGroup, resourceType, name, ) } @@ -131,7 +135,7 @@ func mockVerifyToken(err error) azureVerifyTokenFunc { } } -func makeToken(resourceID string, issueTime time.Time) (string, error) { +func makeToken(managedIdentityResourceID, azureResourceID string, issueTime time.Time) (string, error) { sig, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.HS256, Key: []byte("test-key"), @@ -149,9 +153,10 @@ func makeToken(resourceID string, issueTime time.Time) (string, error) { Expiry: jwt.NewNumericDate(issueTime.Add(time.Minute)), ID: "id", }, - ResourceID: resourceID, - TenantID: "test-tenant-id", - Version: "1.0", + ManangedIdentityResourceID: managedIdentityResourceID, + AzureResourceID: azureResourceID, + TenantID: "test-tenant-id", + Version: "1.0", } raw, err := jwt.Signed(sig).Claims(claims).CompactSerialize() if err != nil { @@ -189,28 +194,28 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { isBadParameter := func(t require.TestingT, err error, _ ...any) { require.True(t, trace.IsBadParameter(err), "expected Bad Parameter error, actual error: %v", err) } - isNotFound := func(t require.TestingT, err error, _ ...any) { 
- require.True(t, trace.IsNotFound(err), "expected Not Found error, actual error: %v", err) - } defaultSubscription := uuid.NewString() defaultResourceGroup := "my-resource-group" - defaultName := "test-vm" + defaultVMName := "test-vm" + defaultIdentityName := "test-id" defaultVMID := "my-vm-id" - defaultResourceID := vmResourceID(defaultSubscription, defaultResourceGroup, defaultName) + defaultVMResourceID := vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName) + defaultIdentityResourceID := identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName) tests := []struct { - name string - tokenResourceID string - tokenSubscription string - tokenVMID string - requestTokenName string - tokenSpec types.ProvisionTokenSpecV2 - challengeResponseOptions []azureChallengeResponseOption - challengeResponseErr error - certs []*x509.Certificate - verify azureVerifyTokenFunc - assertError require.ErrorAssertionFunc + name string + tokenManagedIdentityResourceID string + tokenAzureResourceID string + tokenSubscription string + tokenVMID string + requestTokenName string + tokenSpec types.ProvisionTokenSpecV2 + challengeResponseOptions []azureChallengeResponseOption + challengeResponseErr error + certs []*x509.Certificate + verify azureVerifyTokenFunc + assertError require.ErrorAssertionFunc }{ { name: "basic passing case", @@ -380,10 +385,11 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: require.Error, }, { - name: "attested data and access token from different VMs", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: "some-other-vm-id", + name: "attested data and access token from different VMs", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: "some-other-vm-id", + tokenManagedIdentityResourceID: defaultIdentityResourceID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: 
&types.ProvisionTokenSpecV2Azure{ @@ -400,11 +406,11 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: isAccessDenied, }, { - name: "vm not found", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenResourceID: vmResourceID(defaultSubscription, "nonexistent-group", defaultName), + name: "vm not found", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: "invalid-id", + tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, "invalid-vm"), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -418,14 +424,14 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { }, verify: mockVerifyToken(nil), certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isNotFound, + assertError: isAccessDenied, }, { - name: "lookup vm by id", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenResourceID: resourceID("some.other.provider", defaultSubscription, defaultResourceGroup, defaultName), + name: "lookup vm by id", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: defaultIdentityResourceID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -442,11 +448,11 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: require.NoError, }, { - name: "vm is in a different subscription than the token it provides", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenResourceID: resourceID("some.other.provider", "some-other-subscription", defaultResourceGroup, defaultName), + name: "vm is in a different subscription than the token it provides", + requestTokenName: 
"test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: identityResourceID("some-other-subscription", defaultResourceGroup, defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -476,19 +482,19 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { require.NoError(t, a.DeleteToken(ctx, token.GetName())) }) - rsID := tc.tokenResourceID - if rsID == "" { - rsID = vmResourceID(defaultSubscription, defaultResourceGroup, defaultName) + mirID := tc.tokenManagedIdentityResourceID + if mirID == "" { + mirID = vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName) } - accessToken, err := makeToken(rsID, a.clock.Now()) + accessToken, err := makeToken(mirID, "", a.clock.Now()) require.NoError(t, err) vmClient := &mockAzureVMClient{ vms: map[string]*azure.VirtualMachine{ - defaultResourceID: { - ID: defaultResourceID, - Name: defaultName, + defaultVMResourceID: { + ID: defaultVMResourceID, + Name: defaultVMName, Subscription: defaultSubscription, ResourceGroup: defaultResourceGroup, VMID: defaultVMID, @@ -541,3 +547,278 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { }) } } + +// TestAuth_RegisterUsingAzureClaims tests the Azure join method by verifying +// joining VMs by the token claims rather than from the Azure VM API. 
+func TestAuth_RegisterUsingAzureClaims(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + p, err := newTestPack(ctx, t.TempDir()) + require.NoError(t, err) + a := p.a + + sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + + tlsConfig, err := fixtures.LocalTLSConfig() + require.NoError(t, err) + + block, _ := pem.Decode(fixtures.LocalhostKey) + pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + require.NoError(t, err) + + tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey) + require.NoError(t, err) + + isAccessDenied := func(t require.TestingT, err error, _ ...any) { + require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err) + } + defaultSubscription := uuid.NewString() + defaultResourceGroup := "my-resource-group" + defaultVMName := "test-vm" + defaultIdentityName := "test-id" + defaultVMID := "my-vm-id" + + tests := []struct { + name string + tokenManagedIdentityResourceID string + tokenAzureResourceID string + tokenSubscription string + tokenVMID string + requestTokenName string + tokenSpec types.ProvisionTokenSpecV2 + challengeResponseOptions []azureChallengeResponseOption + challengeResponseErr error + certs []*x509.Certificate + verify azureVerifyTokenFunc + assertError require.ErrorAssertionFunc + }{ + { + name: "system-managed identity ok", + requestTokenName: "test-token", + tokenSubscription: "system-managed-test", + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: vmResourceID("system-managed-test", "system-managed-test", defaultVMName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: "system-managed-test", + ResourceGroups: []string{"system-managed-test"}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: 
mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: require.NoError, + }, + { + name: "system-managed identity with wrong subscription", + requestTokenName: "test-token", + tokenSubscription: "system-managed-test", + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: vmResourceID("system-managed-test", "system-managed-test", defaultVMName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: defaultSubscription, + ResourceGroups: []string{"system-managed-test"}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: isAccessDenied, + }, + { + name: "system-managed identity with wrong resource group", + requestTokenName: "test-token", + tokenSubscription: "system-managed-test", + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: vmResourceID("system-managed-test", "system-managed-test", defaultVMName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: "system-managed-test", + ResourceGroups: []string{defaultResourceGroup}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: isAccessDenied, + }, + { + name: "user-managed identity ok", + requestTokenName: "test-token", + tokenSubscription: "user-managed-test", + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: identityResourceID("user-managed-test", "user-managed-test", defaultIdentityName), + tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: 
[]types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: "user-managed-test", + ResourceGroups: []string{"user-managed-test"}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: require.NoError, + }, + { + name: "user-managed identity with wrong subscription", + requestTokenName: "test-token", + tokenSubscription: "user-managed-test", + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: identityResourceID("user-managed-test", "user-managed-test", defaultIdentityName), + tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: defaultSubscription, + ResourceGroups: []string{"user-managed-test"}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: isAccessDenied, + }, + { + name: "user-managed identity with wrong resource group", + requestTokenName: "test-token", + tokenSubscription: "user-managed-test", + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: identityResourceID("user-managed-test", "user-managed-test", defaultIdentityName), + tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: "user-managed-test", + ResourceGroups: []string{defaultResourceGroup}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: 
[]*x509.Certificate{tlsConfig.Certificate}, + assertError: isAccessDenied, + }, + { + name: "user-managed identity from different subscription", + requestTokenName: "test-token", + tokenSubscription: "user-managed-test", + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: identityResourceID("invalid-user-managed-test", "invalid-user-managed-test", defaultIdentityName), + tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: "user-managed-test", + ResourceGroups: []string{"user-managed-test"}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: require.NoError, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + token, err := types.NewProvisionTokenFromSpec( + "test-token", + time.Now().Add(time.Minute), + tc.tokenSpec) + require.NoError(t, err) + require.NoError(t, a.UpsertToken(ctx, token)) + t.Cleanup(func() { + require.NoError(t, a.DeleteToken(ctx, token.GetName())) + }) + + mirID := tc.tokenManagedIdentityResourceID + azrID := tc.tokenAzureResourceID + accessToken, err := makeToken(mirID, azrID, a.clock.Now()) + require.NoError(t, err) + + vmClient := &mockAzureVMClient{ + vms: map[string]*azure.VirtualMachine{}, + } + getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ + defaultSubscription: vmClient, + }) + + _, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { + cfg := &azureChallengeResponseConfig{Challenge: challenge} + for _, opt := range tc.challengeResponseOptions { + opt(cfg) + } + + ad := attestedData{ + Nonce: cfg.Challenge, + SubscriptionID: tc.tokenSubscription, + ID: 
tc.tokenVMID, + } + adBytes, err := json.Marshal(&ad) + require.NoError(t, err) + s, err := pkcs7.NewSignedData(adBytes) + require.NoError(t, err) + require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{})) + signature, err := s.Finish() + require.NoError(t, err) + signedAD := signedAttestedData{ + Encoding: "pkcs7", + Signature: base64.StdEncoding.EncodeToString(signature), + } + signedADBytes, err := json.Marshal(&signedAD) + require.NoError(t, err) + + req := &proto.RegisterUsingAzureMethodRequest{ + RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ + Token: tc.requestTokenName, + HostID: "test-node", + Role: types.RoleNode, + PublicSSHKey: sshPublicKey, + PublicTLSKey: tlsPublicKey, + }, + AttestedData: signedADBytes, + AccessToken: accessToken, + } + return req, tc.challengeResponseErr + }, withCerts(tc.certs), withVerifyFunc(tc.verify), withVMClientGetter(getVMClient)) + tc.assertError(t, err) + }) + } +}