Skip to content

Commit

Permalink
Fix Azure join method throttling (#50251)
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardjkim committed Jan 9, 2025
1 parent 8d0038d commit 64f0874
Show file tree
Hide file tree
Showing 3 changed files with 442 additions and 78 deletions.
2 changes: 1 addition & 1 deletion lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
rsID := vmResourceID(subID, resourceGroup, "test-vm")
vmID := "vmID"

accessToken, err := makeToken(rsID, a.clock.Now())
accessToken, err := makeToken(rsID, "", a.clock.Now())
require.NoError(t, err)

// add token to auth server
Expand Down
139 changes: 111 additions & 28 deletions lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@
package auth

import (
"cmp"
"context"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"log/slog"
"net/url"
"slices"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/coreos/go-oidc"
"github.com/digitorus/pkcs7"
"github.com/go-jose/go-jose/v3/jwt"
Expand All @@ -38,12 +42,20 @@ import (

"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/utils"
)

const azureAccessTokenAudience = "https://management.azure.com/"
const (
// azureAccessTokenAudience is the Azure Resource Manager endpoint URL.
// NOTE(review): presumably the audience expected on access tokens presented
// by joining VMs — confirm against the token verifier.
azureAccessTokenAudience = "https://management.azure.com/"

// azureUserAgent specifies the Azure User-Agent identification for telemetry.
// It is passed as the telemetry ApplicationID when the virtual machines
// client is constructed, so Teleport traffic can be identified and isolated.
azureUserAgent = "teleport"
// azureVirtualMachine specifies the Azure virtual machine resource type.
// It is matched against the resource types of a parsed Azure resource ID to
// decide whether that ID refers to a VM.
azureVirtualMachine = "virtualMachines"
)

// Structs for unmarshaling attested data. Schema can be found at
// https://learn.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service?tabs=linux#response-2
Expand Down Expand Up @@ -76,9 +88,23 @@ type attestedData struct {

type accessTokenClaims struct {
jwt.Claims
ResourceID string `json:"xms_mirid"`
TenantID string `json:"tid"`
Version string `json:"ver"`
TenantID string `json:"tid"`
Version string `json:"ver"`

// Azure JWT tokens include two optional claims that can be used to validate
// the subscription and resource group of a joining node. These claims hold
// different values depending on the assigned Managed Identity of the Azure VM:
// - xms_mirid:
// - For System-Assigned Identity it represents the resource id of the VM.
// - For User-Assigned Identity it represents the resource id of the user-assigned identity.
// - xms_az_rid:
// - For System-Assigned Identity this claim is omitted.
// - For User-Assigned Identity it represents the resource id of the VM.
//
// More details at: https://learn.microsoft.com/en-us/answers/questions/1282788/existence-of-xms-az-rid-field-in-activity-logs-of

ManangedIdentityResourceID string `json:"xms_mirid"`
AzureResourceID string `json:"xms_az_rid"`
}

type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error)
Expand Down Expand Up @@ -144,7 +170,16 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error {
}
if cfg.getVMClient == nil {
cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) {
client, err := azure.NewVirtualMachinesClient(subscriptionID, token, nil)
// The User-Agent is added for debugging purposes. It helps identify
// and isolate teleport traffic.
opts := &armpolicy.ClientOptions{
ClientOptions: policy.ClientOptions{
Telemetry: policy.TelemetryOptions{
ApplicationID: azureUserAgent,
},
},
}
client, err := azure.NewVirtualMachinesClient(subscriptionID, token, opts)
return client, trace.Wrap(err)
}
}
Expand Down Expand Up @@ -210,8 +245,16 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s
}

// verifyVMIdentity verifies that the provided access token came from the
// correct Azure VM.
func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken, subscriptionID, vmID string, requestStart time.Time) (*azure.VirtualMachine, error) {
// correct Azure VM. Returns the Azure join attributes.
func verifyVMIdentity(
ctx context.Context,
cfg *azureRegisterConfig,
accessToken,
subscriptionID,
vmID string,
requestStart time.Time,
logger *slog.Logger,
) (joinAttrs *workloadidentityv1pb.JoinAttrsAzure, err error) {
tokenClaims, err := cfg.verify(ctx, accessToken)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -239,6 +282,20 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}

// Listing all VMs in an Azure subscription during the verification process
// is problematic when there are a large number of VMs in an Azure subscription.
// In some cases this can lead to throttling due to Azure API rate limits.
// To address the issue, the verification process will first attempt to
// parse required VM identifiers from the token claims. If this method fails,
// fall back to the original method of listing VMs and parsing the VM identifiers
// from the VM resource.
vmSubscription, vmResourceGroup, err := claimsToIdentifiers(tokenClaims)
if err == nil {
return azureJoinToAttrs(vmSubscription, vmResourceGroup), nil
}
logger.WarnContext(ctx, "Failed to parse VM identifiers from claims. Retrying with Azure VM API.",
"error", err)

tokenCredential := azure.NewStaticCredential(azcore.AccessToken{
Token: accessToken,
ExpiresOn: tokenClaims.Expiry.Time(),
Expand All @@ -248,7 +305,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}

resourceID, err := arm.ParseResourceID(tokenClaims.ResourceID)
resourceID, err := arm.ParseResourceID(tokenClaims.ManangedIdentityResourceID)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -257,8 +314,8 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken

// If the token is from the system-assigned managed identity, the resource ID
// is for the VM itself and we can use it to look up the VM.
if slices.Contains(resourceID.ResourceType.Types, "virtualMachines") {
vm, err = vmClient.Get(ctx, tokenClaims.ResourceID)
if slices.Contains(resourceID.ResourceType.Types, azureVirtualMachine) {
vm, err = vmClient.Get(ctx, tokenClaims.ManangedIdentityResourceID)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -277,21 +334,35 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}
}
return azureJoinToAttrs(vm.Subscription, vm.ResourceGroup), nil
}

return vm, nil
// claimsToIdentifiers extracts the subscription ID and resource group name of
// the joining VM from the resource ID carried in the access token claims.
//
// The xms_az_rid claim is preferred; it is omitted when the VM uses a
// System-Assigned Identity, in which case the xms_mirid claim is used instead.
// An error is returned when the chosen resource ID cannot be parsed or does
// not refer to a virtual machine resource.
func claimsToIdentifiers(tokenClaims *accessTokenClaims) (subscriptionID, resourceGroupID string, err error) {
	parsed, err := arm.ParseResourceID(
		cmp.Or(tokenClaims.AzureResourceID, tokenClaims.ManangedIdentityResourceID),
	)
	switch {
	case err != nil:
		return "", "", trace.Wrap(err, "failed to parse resource id from claims")
	case !slices.Contains(parsed.ResourceType.Types, azureVirtualMachine):
		// The resource ID points at something other than a VM (for example a
		// user-assigned identity), so it cannot identify the VM's location.
		return "", "", trace.BadParameter("unexpected resource type: %q", parsed.ResourceType.Type)
	}
	return parsed.SubscriptionID, parsed.ResourceGroupName, nil
}

func checkAzureAllowRules(vm *azure.VirtualMachine, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error {
for _, rule := range allowRules {
if rule.Subscription != vm.Subscription {
func checkAzureAllowRules(vmID string, attrs *workloadidentityv1pb.JoinAttrsAzure, token *types.ProvisionTokenV2) error {
for _, rule := range token.Spec.Azure.Allow {
if rule.Subscription != attrs.Subscription {
continue
}
if !azureResourceGroupIsAllowed(rule.ResourceGroups, vm.ResourceGroup) {
if !azureResourceGroupIsAllowed(rule.ResourceGroups, attrs.ResourceGroup) {
continue
}
return nil
}
return trace.AccessDenied("instance %v did not match any allow rules in token %v", vm.Name, token)
return trace.AccessDenied("instance %v did not match any allow rules in token %v", vmID, token.GetName())
}
func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup string) bool {
if len(allowedResourceGroups) == 0 {
Expand All @@ -312,37 +383,48 @@ func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup
return false
}

func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *proto.RegisterUsingAzureMethodRequest, cfg *azureRegisterConfig) error {
// azureJoinToAttrs builds the Azure join attributes from the given
// subscription ID and resource group.
func azureJoinToAttrs(subscriptionID, resourceGroupID string) *workloadidentityv1pb.JoinAttrsAzure {
	attrs := new(workloadidentityv1pb.JoinAttrsAzure)
	attrs.Subscription = subscriptionID
	attrs.ResourceGroup = resourceGroupID
	return attrs
}

func (a *Server) checkAzureRequest(
ctx context.Context,
challenge string,
req *proto.RegisterUsingAzureMethodRequest,
cfg *azureRegisterConfig,
) (*workloadidentityv1pb.JoinAttrsAzure, error) {
requestStart := a.clock.Now()
tokenName := req.RegisterUsingTokenRequest.Token
provisionToken, err := a.GetToken(ctx, tokenName)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}
if provisionToken.GetJoinMethod() != types.JoinMethodAzure {
return trace.AccessDenied("this token does not support the Azure join method")
return nil, trace.AccessDenied("this token does not support the Azure join method")
}

subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}

vm, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart)
attrs, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart, a.logger)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}

token, ok := provisionToken.(*types.ProvisionTokenV2)
if !ok {
return trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken)
return nil, trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken)
}

if err := checkAzureAllowRules(vm, token.GetName(), token.Spec.Azure.Allow); err != nil {
return trace.Wrap(err)
if err := checkAzureAllowRules(vmID, attrs, token); err != nil {
return attrs, trace.Wrap(err)
}

return nil
return attrs, nil
}

func generateAzureChallenge() (string, error) {
Expand Down Expand Up @@ -399,7 +481,8 @@ func (a *Server) RegisterUsingAzureMethodWithOpts(
return nil, trace.Wrap(err)
}

if err := a.checkAzureRequest(ctx, challenge, req, cfg); err != nil {
_, err = a.checkAzureRequest(ctx, challenge, req, cfg)
if err != nil {
return nil, trace.Wrap(err)
}

Expand Down
Loading

0 comments on commit 64f0874

Please sign in to comment.