From bbe5fd6c896aafae8c79b845e65427a1172bdfaa Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 20 Oct 2021 14:57:20 -0700 Subject: [PATCH 1/5] Add ARM-specific bearer token policy Removed support for auxiliary tenants from the runtime version of this policy as this is specific to ARM. --- sdk/azcore/CHANGELOG.md | 2 + sdk/azcore/arm/runtime/pipeline.go | 16 +- sdk/azcore/arm/runtime/policy_bearer_token.go | 95 ++++++++ .../arm/runtime/policy_bearer_token_test.go | 204 ++++++++++++++++++ sdk/azcore/internal/shared/constants.go | 19 +- .../internal/shared/expiring_resource.go | 99 +++++++++ sdk/azcore/runtime/policy_bearer_token.go | 134 +----------- .../runtime/policy_bearer_token_test.go | 46 +--- .../runtime/transport_default_http_client.go | 3 - 9 files changed, 436 insertions(+), 182 deletions(-) create mode 100644 sdk/azcore/arm/runtime/policy_bearer_token.go create mode 100644 sdk/azcore/arm/runtime/policy_bearer_token_test.go create mode 100644 sdk/azcore/internal/shared/expiring_resource.go diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index bbe90e194b59..0c7e2c295946 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -9,11 +9,13 @@ * `runtime.NewPipeline` has a new signature that simplifies implementing custom authentication * `arm/runtime.RegistrationOptions` embeds `policy.ClientOptions` * Contents in the `log` package have been slightly renamed. +* Removed `AuxiliaryTenants` from `runtime.AuthenticationOptions` as this is ARM-specific ### Features Added * Updating Documentation * Added string typdef `arm.Endpoint` to provide a hint toward expected ARM client endpoints * `azcore.ClientOptions` contains common pipeline configuration settings +* Added support for multi-tenant authorization in `arm/runtime` ### Bug Fixes * Fixed a potential panic when creating the default Transporter. diff --git a/sdk/azcore/arm/runtime/pipeline.go b/sdk/azcore/arm/runtime/pipeline.go index 655b36567904..34e3d89e88b6 100644 --- a/sdk/azcore/arm/runtime/pipeline.go +++ b/sdk/azcore/arm/runtime/pipeline.go @@ -31,13 +31,23 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, options *a perCallPolicies = append(perCallPolicies, NewRPRegistrationPolicy(string(ep), cred, ®RPOpts)) } perRetryPolicies := []policy.Policy{ - azruntime.NewBearerTokenPolicy(cred, azruntime.AuthenticationOptions{ + NewBearerTokenPolicy(cred, AuthenticationOptions{ TokenRequest: policy.TokenRequestOptions{ Scopes: []string{shared.EndpointToScope(string(ep))}, }, AuxiliaryTenants: options.AuxiliaryTenants, - }, - ), + }), } return azruntime.NewPipeline(module, version, perCallPolicies, perRetryPolicies, &options.ClientOptions) } + +// AuthenticationOptions contains various options used to create a credential policy. +type AuthenticationOptions struct { + // TokenRequest is a TokenRequestOptions that includes a scopes field which contains + // the list of OAuth2 authentication scopes used when requesting a token. + // This field is ignored for other forms of authentication (e.g. shared key). + TokenRequest policy.TokenRequestOptions + // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate + // in cross-tenant applications. + AuxiliaryTenants []string +} diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go new file mode 100644 index 000000000000..5af149d3b3e7 --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "fmt" + "net/http" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +type acquiringResourceState struct { + req *policy.Request + p BearerTokenPolicy +} + +// acquire acquires or updates the resource; only one +// thread/goroutine at a time ever calls this function +func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { + s := state.(acquiringResourceState) + tk, err := s.p.cred.GetToken(s.req.Raw().Context(), s.p.options) + if err != nil { + return nil, time.Time{}, err + } + return tk, tk.ExpiresOn, nil +} + +// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. +type BearerTokenPolicy struct { + // mainResource is the resource to be retreived using the tenant specified in the credential + mainResource *shared.ExpiringResource + // auxResources are additional resources that are required for cross-tenant applications + auxResources map[string]*shared.ExpiringResource + // the following fields are read-only + cred azcore.TokenCredential + options policy.TokenRequestOptions +} + +// NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens. +// cred: an azcore.TokenCredential implementation such as a credential object from azidentity +// opts: optional settings. Pass nil to accept default values; this is the same as passing a zero-value options. +func NewBearerTokenPolicy(cred azcore.TokenCredential, opts AuthenticationOptions) *BearerTokenPolicy { + p := &BearerTokenPolicy{ + cred: cred, + options: opts.TokenRequest, + mainResource: shared.NewExpiringResource(acquire), + } + if len(opts.AuxiliaryTenants) > 0 { + p.auxResources = map[string]*shared.ExpiringResource{} + } + for _, t := range opts.AuxiliaryTenants { + p.auxResources[t] = shared.NewExpiringResource(acquire) + + } + return p +} + +// Do authorizes a request with a bearer token +func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { + as := acquiringResourceState{ + p: *b, + req: req, + } + tk, err := b.mainResource.GetResource(as) + if err != nil { + return nil, err + } + if token, ok := tk.(*azcore.AccessToken); ok { + req.Raw().Header.Set(shared.HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat)) + req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token) + } + auxTokens := []string{} + for tenant, er := range b.auxResources { + bCopy := *b + bCopy.options.TenantID = tenant + auxAS := acquiringResourceState{ + p: bCopy, + req: req, + } + auxTk, err := er.GetResource(auxAS) + if err != nil { + return nil, err + } + auxTokens = append(auxTokens, fmt.Sprintf("%s%s", shared.BearerTokenPrefix, auxTk.(*azcore.AccessToken).Token)) + } + if len(auxTokens) > 0 { + req.Raw().Header.Set(shared.HeaderAuxiliaryAuthorization, strings.Join(auxTokens, ", ")) + } + return req.Next() +} diff --git a/sdk/azcore/arm/runtime/policy_bearer_token_test.go b/sdk/azcore/arm/runtime/policy_bearer_token_test.go new file mode 100644 index 000000000000..fc2dec1b6ee7 --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_bearer_token_test.go @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "strings" + + "errors" + "net/http" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const ( + tokenValue = "***" + accessTokenRespSuccess = `{"access_token": "` + tokenValue + `", "expires_in": 3600}` + accessTokenRespShortLived = `{"access_token": "` + tokenValue + `", "expires_in": 0}` + scope = "scope" +) + +type mockCredential struct { + getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) +} + +func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { + if mc.getTokenImpl != nil { + return mc.getTokenImpl(ctx, options) + } + return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil +} + +func (mc mockCredential) NewAuthenticationPolicy(options AuthenticationOptions) policy.Policy { + return mc +} + +func (mc mockCredential) Do(req *policy.Request) (*http.Response, error) { + return nil, nil +} + +func newTestPipeline(opts *policy.ClientOptions) pipeline.Pipeline { + return runtime.NewPipeline("testmodule", "v0.1.0", nil, nil, opts) +} + +func defaultTestPipeline(srv policy.Transporter, scope string) pipeline.Pipeline { + retryOpts := policy.RetryOptions{ + MaxRetryDelay: 500 * time.Millisecond, + RetryDelay: time.Millisecond, + } + return NewPipeline( + "testmodule", + "v0.1.0", + mockCredential{}, + &arm.ClientOptions{ + ClientOptions: azcore.ClientOptions{ + Retry: retryOpts, + Transport: srv, + }, + }) +} + +func TestBearerPolicy_SuccessGetToken(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + pipeline := defaultTestPipeline(srv, scope) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := pipeline.Do(req) + if err != nil { + t.Fatalf("Expected nil error but received one") + } + const expectedToken = shared.BearerTokenPrefix + tokenValue + if token := resp.Request.Header.Get(shared.HeaderAuthorization); token != expectedToken { + t.Fatalf("expected token '%s', got '%s'", expectedToken, token) + } +} + +func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + expectedErr := errors.New("oops") + failCredential := mockCredential{} + failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { + return nil, expectedErr + } + b := NewBearerTokenPolicy(failCredential, AuthenticationOptions{}) + pipeline := newTestPipeline(&policy.ClientOptions{ + Transport: srv, + Retry: policy.RetryOptions{ + RetryDelay: 10 * time.Millisecond, + }, + PerRetryPolicies: []policy.Policy{b}, + }) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := pipeline.Do(req) + if err != expectedErr { + t.Fatalf("unexpected error: %v", err) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestBearerTokenPolicy_TokenExpired(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespShortLived))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + pipeline := defaultTestPipeline(srv, scope) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + _, err = pipeline.Do(req) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + _, err = pipeline.Do(req) + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + +func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + retryOpts := policy.RetryOptions{ + // use a negative try timeout to trigger a deadline exceeded error causing GetToken() to fail + TryTimeout: -1 * time.Nanosecond, + MaxRetryDelay: 500 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + MaxRetries: 3, + } + b := NewBearerTokenPolicy(mockCredential{}, AuthenticationOptions{}) + pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := pipeline.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestBearerTokenWithAuxiliaryTenants(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse() + retryOpts := policy.RetryOptions{ + MaxRetryDelay: 500 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + } + b := NewBearerTokenPolicy( + mockCredential{}, + AuthenticationOptions{ + TokenRequest: policy.TokenRequestOptions{ + Scopes: []string{scope}, + }, + AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"}, + }, + ) + pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + resp, err := pipeline.Do(req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + expectedHeader := strings.Repeat(shared.BearerTokenPrefix+tokenValue+", ", 3) + expectedHeader = expectedHeader[:len(expectedHeader)-2] + if auxH := resp.Request.Header.Get(shared.HeaderAuxiliaryAuthorization); auxH != expectedHeader { + t.Fatalf("unexpected auxiliary authorization header %s", auxH) + } +} diff --git a/sdk/azcore/internal/shared/constants.go b/sdk/azcore/internal/shared/constants.go index 06c5d32fd03c..5103d87e0053 100644 --- a/sdk/azcore/internal/shared/constants.go +++ b/sdk/azcore/internal/shared/constants.go @@ -12,19 +12,24 @@ const ( ) const ( - HeaderAzureAsync = "Azure-AsyncOperation" - HeaderContentLength = "Content-Length" - HeaderContentType = "Content-Type" - HeaderLocation = "Location" - HeaderOperationLocation = "Operation-Location" - HeaderRetryAfter = "Retry-After" - HeaderUserAgent = "User-Agent" + HeaderAuthorization = "Authorization" + HeaderAuxiliaryAuthorization = "x-ms-authorization-auxiliary" + HeaderAzureAsync = "Azure-AsyncOperation" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderLocation = "Location" + HeaderOperationLocation = "Operation-Location" + HeaderRetryAfter = "Retry-After" + HeaderUserAgent = "User-Agent" + HeaderXmsDate = "x-ms-date" ) const ( DefaultMaxRetries = 3 ) +const BearerTokenPrefix = "Bearer " + const ( // Module is the name of the calling module used in telemetry data. Module = "azcore" diff --git a/sdk/azcore/internal/shared/expiring_resource.go b/sdk/azcore/internal/shared/expiring_resource.go new file mode 100644 index 000000000000..9f97ca9559ab --- /dev/null +++ b/sdk/azcore/internal/shared/expiring_resource.go @@ -0,0 +1,99 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +import ( + "sync" + "time" +) + +// AcquireResource abstracts a method for refreshing an expiring resource. +type AcquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) + +// ExpiringResource is a temporal resource (usually a credential), that requires periodic refreshing. +type ExpiringResource struct { + // cond is used to synchronize access to the shared resource embodied by the remaining fields + cond *sync.Cond + + // acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource + acquiring bool + + // resource contains the value of the shared resource + resource interface{} + + // expiration indicates when the shared resource expires; it is 0 if the resource was never acquired + expiration time.Time + + // acquireResource is the callback function that actually acquires the resource + acquireResource AcquireResource +} + +// NewExpiringResource creates a new ExpiringResource that uses the specified AcquireResource for refreshing. +func NewExpiringResource(ar AcquireResource) *ExpiringResource { + return &ExpiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar} +} + +// GetResource returns the underlying resource. +// If the resource is fresh, no refresh is performed. +func (er *ExpiringResource) GetResource(state interface{}) (interface{}, error) { + // If the resource is expiring within this time window, update it eagerly. + // This allows other threads/goroutines to keep running by using the not-yet-expired + // resource value while one thread/goroutine updates the resource. + const window = 2 * time.Minute // This example updates the resource 2 minutes prior to expiration + + now, acquire, resource := time.Now(), false, er.resource + // acquire exclusive lock + er.cond.L.Lock() + for { + if er.expiration.IsZero() || er.expiration.Before(now) { + // The resource was never acquired or has expired + if !er.acquiring { + // If another thread/goroutine is not acquiring/updating the resource, this thread/goroutine will do it + er.acquiring, acquire = true, true + break + } + // Getting here means that this thread/goroutine will wait for the updated resource + } else if er.expiration.Add(-window).Before(now) { + // The resource is valid but is expiring within the time window + if !er.acquiring { + // If another thread/goroutine is not acquiring/renewing the resource, this thread/goroutine will do it + er.acquiring, acquire = true, true + break + } + // This thread/goroutine will use the existing resource value while another updates it + resource = er.resource + break + } else { + // The resource is not close to expiring, this thread/goroutine should use its current value + resource = er.resource + break + } + // If we get here, wait for the new resource value to be acquired/updated + er.cond.Wait() + } + er.cond.L.Unlock() // Release the lock so no threads/goroutines are blocked + + var err error + if acquire { + // This thread/goroutine has been selected to acquire/update the resource + var expiration time.Time + resource, expiration, err = er.acquireResource(state) + + // Atomically, update the shared resource's new value & expiration. + er.cond.L.Lock() + if err == nil { + // No error, update resource & expiration + er.resource, er.expiration = resource, expiration + } + er.acquiring = false // Indicate that no thread/goroutine is currently acquiring the resrouce + + // Wake up any waiting threads/goroutines since there is a resource they can ALL use + er.cond.L.Unlock() + er.cond.Broadcast() + } + return resource, err // Return the resource this thread/goroutine can use +} diff --git a/sdk/azcore/runtime/policy_bearer_token.go b/sdk/azcore/runtime/policy_bearer_token.go index 2c5a22b6c401..679f8159cf30 100644 --- a/sdk/azcore/runtime/policy_bearer_token.go +++ b/sdk/azcore/runtime/policy_bearer_token.go @@ -4,56 +4,26 @@ package runtime import ( - "fmt" "net/http" - "strings" - "sync" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) -const ( - bearerTokenPrefix = "Bearer " - headerXmsDate = "x-ms-date" - headerAuthorization = "Authorization" - headerAuxiliaryAuthorization = "x-ms-authorization-auxiliary" -) - // BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. type BearerTokenPolicy struct { // mainResource is the resource to be retreived using the tenant specified in the credential - mainResource *expiringResource - // auxResources are additional resources that are required for cross-tenant applications - auxResources map[string]*expiringResource + mainResource *shared.ExpiringResource // the following fields are read-only cred azcore.TokenCredential options policy.TokenRequestOptions } -type expiringResource struct { - // cond is used to synchronize access to the shared resource embodied by the remaining fields - cond *sync.Cond - - // acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource - acquiring bool - - // resource contains the value of the shared resource - resource interface{} - - // expiration indicates when the shared resource expires; it is 0 if the resource was never acquired - expiration time.Time - - // acquireResource is the callback function that actually acquires the resource - acquireResource acquireResource -} - -type acquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) - type acquiringResourceState struct { req *policy.Request - p BearerTokenPolicy + p *BearerTokenPolicy } // acquire acquires or updates the resource; only one @@ -67,92 +37,21 @@ func acquire(state interface{}) (newResource interface{}, newExpiration time.Tim return tk, tk.ExpiresOn, nil } -func newExpiringResource(ar acquireResource) *expiringResource { - return &expiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar} -} - -func (er *expiringResource) GetResource(state interface{}) (interface{}, error) { - // If the resource is expiring within this time window, update it eagerly. - // This allows other threads/goroutines to keep running by using the not-yet-expired - // resource value while one thread/goroutine updates the resource. - const window = 2 * time.Minute // This example updates the resource 2 minutes prior to expiration - - now, acquire, resource := time.Now(), false, er.resource - // acquire exclusive lock - er.cond.L.Lock() - for { - if er.expiration.IsZero() || er.expiration.Before(now) { - // The resource was never acquired or has expired - if !er.acquiring { - // If another thread/goroutine is not acquiring/updating the resource, this thread/goroutine will do it - er.acquiring, acquire = true, true - break - } - // Getting here means that this thread/goroutine will wait for the updated resource - } else if er.expiration.Add(-window).Before(now) { - // The resource is valid but is expiring within the time window - if !er.acquiring { - // If another thread/goroutine is not acquiring/renewing the resource, this thread/goroutine will do it - er.acquiring, acquire = true, true - break - } - // This thread/goroutine will use the existing resource value while another updates it - resource = er.resource - break - } else { - // The resource is not close to expiring, this thread/goroutine should use its current value - resource = er.resource - break - } - // If we get here, wait for the new resource value to be acquired/updated - er.cond.Wait() - } - er.cond.L.Unlock() // Release the lock so no threads/goroutines are blocked - - var err error - if acquire { - // This thread/goroutine has been selected to acquire/update the resource - var expiration time.Time - resource, expiration, err = er.acquireResource(state) - - // Atomically, update the shared resource's new value & expiration. - er.cond.L.Lock() - if err == nil { - // No error, update resource & expiration - er.resource, er.expiration = resource, expiration - } - er.acquiring = false // Indicate that no thread/goroutine is currently acquiring the resrouce - - // Wake up any waiting threads/goroutines since there is a resource they can ALL use - er.cond.L.Unlock() - er.cond.Broadcast() - } - return resource, err // Return the resource this thread/goroutine can use -} - // NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens. // cred: an azcore.TokenCredential implementation such as a credential object from azidentity // opts: optional settings. Pass nil to accept default values; this is the same as passing a zero-value options. func NewBearerTokenPolicy(cred azcore.TokenCredential, opts AuthenticationOptions) *BearerTokenPolicy { - p := &BearerTokenPolicy{ + return &BearerTokenPolicy{ cred: cred, options: opts.TokenRequest, - mainResource: newExpiringResource(acquire), - } - if len(opts.AuxiliaryTenants) > 0 { - p.auxResources = map[string]*expiringResource{} - } - for _, t := range opts.AuxiliaryTenants { - p.auxResources[t] = newExpiringResource(acquire) - + mainResource: shared.NewExpiringResource(acquire), } - return p } // Do authorizes a request with a bearer token func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { as := acquiringResourceState{ - p: *b, + p: b, req: req, } tk, err := b.mainResource.GetResource(as) @@ -160,25 +59,8 @@ func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { return nil, err } if token, ok := tk.(*azcore.AccessToken); ok { - req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat)) - req.Raw().Header.Set(headerAuthorization, bearerTokenPrefix+token.Token) - } - auxTokens := []string{} - for tenant, er := range b.auxResources { - bCopy := *b - bCopy.options.TenantID = tenant - auxAS := acquiringResourceState{ - p: bCopy, - req: req, - } - auxTk, err := er.GetResource(auxAS) - if err != nil { - return nil, err - } - auxTokens = append(auxTokens, fmt.Sprintf("%s%s", bearerTokenPrefix, auxTk.(*azcore.AccessToken).Token)) - } - if len(auxTokens) > 0 { - req.Raw().Header.Set(headerAuxiliaryAuthorization, strings.Join(auxTokens, ", ")) + req.Raw().Header.Set(shared.HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat)) + req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token) } return req.Next() } diff --git a/sdk/azcore/runtime/policy_bearer_token_test.go b/sdk/azcore/runtime/policy_bearer_token_test.go index 36440cbff8dd..39d13682ebfc 100644 --- a/sdk/azcore/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/runtime/policy_bearer_token_test.go @@ -5,7 +5,6 @@ package runtime import ( "context" - "strings" "errors" "net/http" @@ -14,6 +13,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -76,8 +76,8 @@ func TestBearerPolicy_SuccessGetToken(t *testing.T) { if err != nil { t.Fatalf("Expected nil error but received one") } - const expectedToken = bearerTokenPrefix + tokenValue - if token := resp.Request.Header.Get(headerAuthorization); token != expectedToken { + const expectedToken = shared.BearerTokenPrefix + tokenValue + if token := resp.Request.Header.Get(shared.HeaderAuthorization); token != expectedToken { t.Fatalf("expected token '%s', got '%s'", expectedToken, token) } } @@ -156,43 +156,3 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { t.Fatal("expected nil response") } } - -func TestBearerTokenWithAuxiliaryTenants(t *testing.T) { - srv, close := mock.NewTLSServer() - defer close() - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse() - retryOpts := policy.RetryOptions{ - MaxRetryDelay: 500 * time.Millisecond, - RetryDelay: 50 * time.Millisecond, - } - b := NewBearerTokenPolicy( - mockCredential{}, - AuthenticationOptions{ - TokenRequest: policy.TokenRequestOptions{ - Scopes: []string{scope}, - }, - AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"}, - }, - ) - pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - resp, err := pipeline.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status code: %d", resp.StatusCode) - } - expectedHeader := strings.Repeat(bearerTokenPrefix+tokenValue+", ", 3) - expectedHeader = expectedHeader[:len(expectedHeader)-2] - if auxH := resp.Request.Header.Get(headerAuxiliaryAuthorization); auxH != expectedHeader { - t.Fatalf("unexpected auxiliary authorization header %s", auxH) - } -} diff --git a/sdk/azcore/runtime/transport_default_http_client.go b/sdk/azcore/runtime/transport_default_http_client.go index d8bb8643c2ae..3fe1fa2435e5 100644 --- a/sdk/azcore/runtime/transport_default_http_client.go +++ b/sdk/azcore/runtime/transport_default_http_client.go @@ -44,7 +44,4 @@ type AuthenticationOptions struct { // the list of OAuth2 authentication scopes used when requesting a token. // This field is ignored for other forms of authentication (e.g. shared key). TokenRequest policy.TokenRequestOptions - // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate - // in cross-tenant applications. - AuxiliaryTenants []string } From 049a9aa75c93671c18aa809773fa7108fa709dbe Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 20 Oct 2021 15:23:31 -0700 Subject: [PATCH 2/5] add tests for expiring resource --- .../internal/shared/expiring_resource_test.go | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 sdk/azcore/internal/shared/expiring_resource_test.go diff --git a/sdk/azcore/internal/shared/expiring_resource_test.go b/sdk/azcore/internal/shared/expiring_resource_test.go new file mode 100644 index 000000000000..9ccc244ced01 --- /dev/null +++ b/sdk/azcore/internal/shared/expiring_resource_test.go @@ -0,0 +1,48 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewExpiringResource(t *testing.T) { + er := NewExpiringResource(func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { + s := state.(string) + switch s { + case "initial": + return "updated", time.Now(), nil + case "updated": + return "refreshed", time.Now().Add(1 * time.Hour), nil + default: + t.Fatalf("unexpected state %s", s) + return "", time.Time{}, errors.New("unexpected") + } + }) + res, err := er.GetResource("initial") + require.NoError(t, err) + require.Equal(t, "updated", res) + res, err = er.GetResource(res) + require.NoError(t, err) + require.Equal(t, "refreshed", res) + res, err = er.GetResource(res) + require.NoError(t, err) + require.Equal(t, "refreshed", res) +} + +func TestNewExpiringResourceError(t *testing.T) { + er := NewExpiringResource(func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { + return "", time.Time{}, errors.New("failed") + }) + res, err := er.GetResource("stale") + require.Error(t, err) + require.Equal(t, "", res) +} From 3dc7f96cb16b81c647d5371d80268d8275547cf7 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 20 Oct 2021 15:55:49 -0700 Subject: [PATCH 3/5] remove superfluous x-ms-date header --- sdk/azcore/arm/runtime/policy_bearer_token.go | 1 - sdk/azcore/runtime/policy_bearer_token.go | 1 - 2 files changed, 2 deletions(-) diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go index 5af149d3b3e7..a4d9fd0300a7 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -71,7 +71,6 @@ func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { return nil, err } if token, ok := tk.(*azcore.AccessToken); ok { - req.Raw().Header.Set(shared.HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat)) req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token) } auxTokens := []string{} diff --git a/sdk/azcore/runtime/policy_bearer_token.go b/sdk/azcore/runtime/policy_bearer_token.go index 679f8159cf30..0dc80e914659 100644 --- a/sdk/azcore/runtime/policy_bearer_token.go +++ b/sdk/azcore/runtime/policy_bearer_token.go @@ -59,7 +59,6 @@ func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { return nil, err } if token, ok := tk.(*azcore.AccessToken); ok { - req.Raw().Header.Set(shared.HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat)) req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token) } return req.Next() From 2087c4a6104cb8770db8d85dc9eb46d523f95bae Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Thu, 21 Oct 2021 08:06:05 -0700 Subject: [PATCH 4/5] remove policy.TokenRequestOptions from AuthenticationOptions --- sdk/azcore/arm/runtime/pipeline.go | 10 ++----- sdk/azcore/arm/runtime/policy_bearer_token.go | 28 +++++++++---------- .../arm/runtime/policy_bearer_token_test.go | 4 +-- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/sdk/azcore/arm/runtime/pipeline.go b/sdk/azcore/arm/runtime/pipeline.go index 34e3d89e88b6..c67d5ed1f988 100644 --- a/sdk/azcore/arm/runtime/pipeline.go +++ b/sdk/azcore/arm/runtime/pipeline.go @@ -32,9 +32,7 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, options *a } perRetryPolicies := []policy.Policy{ NewBearerTokenPolicy(cred, AuthenticationOptions{ - TokenRequest: policy.TokenRequestOptions{ - Scopes: []string{shared.EndpointToScope(string(ep))}, - }, + Scopes: []string{shared.EndpointToScope(string(ep))}, AuxiliaryTenants: options.AuxiliaryTenants, }), } @@ -43,10 +41,8 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, options *a // AuthenticationOptions contains various options used to create a credential policy. type AuthenticationOptions struct { - // TokenRequest is a TokenRequestOptions that includes a scopes field which contains - // the list of OAuth2 authentication scopes used when requesting a token. - // This field is ignored for other forms of authentication (e.g. shared key). - TokenRequest policy.TokenRequestOptions + // Scopes contains the list of permission scopes required for the token. + Scopes []string // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate // in cross-tenant applications. AuxiliaryTenants []string diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go index a4d9fd0300a7..18de807c077d 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -4,6 +4,7 @@ package runtime import ( + "context" "fmt" "net/http" "strings" @@ -15,15 +16,19 @@ import ( ) type acquiringResourceState struct { - req *policy.Request - p BearerTokenPolicy + ctx context.Context + p *BearerTokenPolicy + tenant string } // acquire acquires or updates the resource; only one // thread/goroutine at a time ever calls this function func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { s := state.(acquiringResourceState) - tk, err := s.p.cred.GetToken(s.req.Raw().Context(), s.p.options) + tk, err := s.p.cred.GetToken(s.ctx, policy.TokenRequestOptions{ + Scopes: s.p.options.Scopes, + TenantID: s.tenant, + }) if err != nil { return nil, time.Time{}, err } @@ -38,7 +43,7 @@ type BearerTokenPolicy struct { auxResources map[string]*shared.ExpiringResource // the following fields are read-only cred azcore.TokenCredential - options policy.TokenRequestOptions + options AuthenticationOptions } // NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens. @@ -47,7 +52,7 @@ type BearerTokenPolicy struct { func NewBearerTokenPolicy(cred azcore.TokenCredential, opts AuthenticationOptions) *BearerTokenPolicy { p := &BearerTokenPolicy{ cred: cred, - options: opts.TokenRequest, + options: opts, mainResource: shared.NewExpiringResource(acquire), } if len(opts.AuxiliaryTenants) > 0 { @@ -63,8 +68,8 @@ func NewBearerTokenPolicy(cred azcore.TokenCredential, opts AuthenticationOption // Do authorizes a request with a bearer token func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { as := acquiringResourceState{ - p: *b, - req: req, + ctx: req.Raw().Context(), + p: b, } tk, err := b.mainResource.GetResource(as) if err != nil { @@ -75,13 +80,8 @@ func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { } auxTokens := []string{} for tenant, er := range b.auxResources { - bCopy := *b - bCopy.options.TenantID = tenant - auxAS := acquiringResourceState{ - p: bCopy, - req: req, - } - auxTk, err := er.GetResource(auxAS) + as.tenant = tenant + auxTk, err := er.GetResource(as) if err != nil { return nil, err } diff --git a/sdk/azcore/arm/runtime/policy_bearer_token_test.go b/sdk/azcore/arm/runtime/policy_bearer_token_test.go index fc2dec1b6ee7..a5c07e0f6a9a 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token_test.go @@ -178,9 +178,7 @@ func TestBearerTokenWithAuxiliaryTenants(t *testing.T) { b := NewBearerTokenPolicy( mockCredential{}, AuthenticationOptions{ - TokenRequest: policy.TokenRequestOptions{ - Scopes: []string{scope}, - }, + Scopes: []string{scope}, AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"}, }, ) From d15b6526fe95832bbf37887c86246940aed72094 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Thu, 21 Oct 2021 12:17:46 -0700 Subject: [PATCH 5/5] refactor bearer token policy constructors --- sdk/azcore/CHANGELOG.md | 4 +- sdk/azcore/arm/policy/policy.go | 44 +++++++++++++++++++ sdk/azcore/arm/runtime/pipeline.go | 20 +++------ sdk/azcore/arm/runtime/policy_bearer_token.go | 16 ++++--- .../arm/runtime/policy_bearer_token_test.go | 39 ++++++++-------- sdk/azcore/arm/runtime/policy_register_rp.go | 43 +++++------------- .../arm/runtime/policy_register_rp_test.go | 19 ++++---- sdk/azcore/policy/policy.go | 5 +++ sdk/azcore/runtime/policy_bearer_token.go | 11 ++--- .../runtime/policy_bearer_token_test.go | 11 ++--- .../runtime/transport_default_http_client.go | 10 ----- 11 files changed, 119 insertions(+), 103 deletions(-) create mode 100644 sdk/azcore/arm/policy/policy.go diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 0c7e2c295946..007dcf14b951 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -9,7 +9,9 @@ * `runtime.NewPipeline` has a new signature that simplifies implementing custom authentication * `arm/runtime.RegistrationOptions` embeds `policy.ClientOptions` * Contents in the `log` package have been slightly renamed. -* Removed `AuxiliaryTenants` from `runtime.AuthenticationOptions` as this is ARM-specific +* Removed `AuthenticationOptions` in favor of `policy.BearerTokenOptions` +* Changed parameters for `NewBearerTokenPolicy()` +* Moved policy config options out of `arm/runtime` and into `arm/policy` ### Features Added * Updating Documentation diff --git a/sdk/azcore/arm/policy/policy.go b/sdk/azcore/arm/policy/policy.go new file mode 100644 index 000000000000..f49dbc313282 --- /dev/null +++ b/sdk/azcore/arm/policy/policy.go @@ -0,0 +1,44 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package policy + +import ( + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// BearerTokenOptions configures the bearer token policy's behavior. +type BearerTokenOptions struct { + // Scopes contains the list of permission scopes required for the token. + Scopes []string + // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate + // in cross-tenant applications. + AuxiliaryTenants []string +} + +// RegistrationOptions configures the registration policy's behavior. +// All zero-value fields will be initialized with their default values. +type RegistrationOptions struct { + policy.ClientOptions + + // MaxAttempts is the total number of times to attempt automatic registration + // in the event that an attempt fails. + // The default value is 3. + // Set to a value less than zero to disable the policy. + MaxAttempts int + + // PollingDelay is the amount of time to sleep between polling intervals. + // The default value is 15 seconds. + // A value less than zero means no delay between polling intervals (not recommended). + PollingDelay time.Duration + + // PollingDuration is the amount of time to wait before abandoning polling. + // The default valule is 5 minutes. + // NOTE: Setting this to a small value might cause the policy to prematurely fail. + PollingDuration time.Duration +} diff --git a/sdk/azcore/arm/runtime/pipeline.go b/sdk/azcore/arm/runtime/pipeline.go index c67d5ed1f988..cc1974d3f2b7 100644 --- a/sdk/azcore/arm/runtime/pipeline.go +++ b/sdk/azcore/arm/runtime/pipeline.go @@ -9,9 +9,10 @@ package runtime import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" ) @@ -25,25 +26,16 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, options *a if len(ep) == 0 { ep = arm.AzurePublicCloud } - perCallPolicies := []policy.Policy{} + perCallPolicies := []azpolicy.Policy{} if !options.DisableRPRegistration { - regRPOpts := RegistrationOptions{ClientOptions: options.ClientOptions} + regRPOpts := armpolicy.RegistrationOptions{ClientOptions: options.ClientOptions} perCallPolicies = append(perCallPolicies, NewRPRegistrationPolicy(string(ep), cred, ®RPOpts)) } - perRetryPolicies := []policy.Policy{ - NewBearerTokenPolicy(cred, AuthenticationOptions{ + perRetryPolicies := []azpolicy.Policy{ + NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{ Scopes: []string{shared.EndpointToScope(string(ep))}, AuxiliaryTenants: options.AuxiliaryTenants, }), } return azruntime.NewPipeline(module, version, perCallPolicies, perRetryPolicies, &options.ClientOptions) } - -// AuthenticationOptions contains various options used to create a credential policy. -type AuthenticationOptions struct { - // Scopes contains the list of permission scopes required for the token. - Scopes []string - // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate - // in cross-tenant applications. - AuxiliaryTenants []string -} diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go index 18de807c077d..ada0405e8f3d 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -11,8 +11,9 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) type acquiringResourceState struct { @@ -25,7 +26,7 @@ type acquiringResourceState struct { // thread/goroutine at a time ever calls this function func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { s := state.(acquiringResourceState) - tk, err := s.p.cred.GetToken(s.ctx, policy.TokenRequestOptions{ + tk, err := s.p.cred.GetToken(s.ctx, azpolicy.TokenRequestOptions{ Scopes: s.p.options.Scopes, TenantID: s.tenant, }) @@ -43,16 +44,19 @@ type BearerTokenPolicy struct { auxResources map[string]*shared.ExpiringResource // the following fields are read-only cred azcore.TokenCredential - options AuthenticationOptions + options armpolicy.BearerTokenOptions } // NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens. // cred: an azcore.TokenCredential implementation such as a credential object from azidentity // opts: optional settings. Pass nil to accept default values; this is the same as passing a zero-value options. -func NewBearerTokenPolicy(cred azcore.TokenCredential, opts AuthenticationOptions) *BearerTokenPolicy { +func NewBearerTokenPolicy(cred azcore.TokenCredential, opts *armpolicy.BearerTokenOptions) *BearerTokenPolicy { + if opts == nil { + opts = &armpolicy.BearerTokenOptions{} + } p := &BearerTokenPolicy{ cred: cred, - options: opts, + options: *opts, mainResource: shared.NewExpiringResource(acquire), } if len(opts.AuxiliaryTenants) > 0 { @@ -66,7 +70,7 @@ func NewBearerTokenPolicy(cred azcore.TokenCredential, opts AuthenticationOption } // Do authorizes a request with a bearer token -func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { +func (b *BearerTokenPolicy) Do(req *azpolicy.Request) (*http.Response, error) { as := acquiringResourceState{ ctx: req.Raw().Context(), p: b, diff --git a/sdk/azcore/arm/runtime/policy_bearer_token_test.go b/sdk/azcore/arm/runtime/policy_bearer_token_test.go index a5c07e0f6a9a..d17aa2813b9c 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token_test.go @@ -14,9 +14,10 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -29,30 +30,30 @@ const ( ) type mockCredential struct { - getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) + getTokenImpl func(ctx context.Context, options azpolicy.TokenRequestOptions) (*azcore.AccessToken, error) } -func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { +func (mc mockCredential) GetToken(ctx context.Context, options azpolicy.TokenRequestOptions) (*azcore.AccessToken, error) { if mc.getTokenImpl != nil { return mc.getTokenImpl(ctx, options) } return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil } -func (mc mockCredential) NewAuthenticationPolicy(options AuthenticationOptions) policy.Policy { +func (mc mockCredential) NewAuthenticationPolicy() azpolicy.Policy { return mc } -func (mc mockCredential) Do(req *policy.Request) (*http.Response, error) { +func (mc mockCredential) Do(req *azpolicy.Request) (*http.Response, error) { return nil, nil } -func newTestPipeline(opts *policy.ClientOptions) pipeline.Pipeline { +func newTestPipeline(opts *azpolicy.ClientOptions) pipeline.Pipeline { return runtime.NewPipeline("testmodule", "v0.1.0", nil, nil, opts) } -func defaultTestPipeline(srv policy.Transporter, scope string) pipeline.Pipeline { - retryOpts := policy.RetryOptions{ +func defaultTestPipeline(srv azpolicy.Transporter, scope string) pipeline.Pipeline { + retryOpts := azpolicy.RetryOptions{ MaxRetryDelay: 500 * time.Millisecond, RetryDelay: time.Millisecond, } @@ -93,16 +94,16 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { defer close() expectedErr := errors.New("oops") failCredential := mockCredential{} - failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { + failCredential.getTokenImpl = func(ctx context.Context, options azpolicy.TokenRequestOptions) (*azcore.AccessToken, error) { return nil, expectedErr } - b := NewBearerTokenPolicy(failCredential, AuthenticationOptions{}) - pipeline := newTestPipeline(&policy.ClientOptions{ + b := NewBearerTokenPolicy(failCredential, nil) + pipeline := newTestPipeline(&azpolicy.ClientOptions{ Transport: srv, - Retry: policy.RetryOptions{ + Retry: azpolicy.RetryOptions{ RetryDelay: 10 * time.Millisecond, }, - PerRetryPolicies: []policy.Policy{b}, + PerRetryPolicies: []azpolicy.Policy{b}, }) req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { @@ -141,15 +142,15 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { srv, close := mock.NewTLSServer() defer close() srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - retryOpts := policy.RetryOptions{ + retryOpts := azpolicy.RetryOptions{ // use a negative try timeout to trigger a deadline exceeded error causing GetToken() to fail TryTimeout: -1 * time.Nanosecond, MaxRetryDelay: 500 * time.Millisecond, RetryDelay: 50 * time.Millisecond, MaxRetries: 3, } - b := NewBearerTokenPolicy(mockCredential{}, AuthenticationOptions{}) - pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) + b := NewBearerTokenPolicy(mockCredential{}, nil) + pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatal(err) @@ -171,18 +172,18 @@ func TestBearerTokenWithAuxiliaryTenants(t *testing.T) { srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) srv.AppendResponse() - retryOpts := policy.RetryOptions{ + retryOpts := azpolicy.RetryOptions{ MaxRetryDelay: 500 * time.Millisecond, RetryDelay: 50 * time.Millisecond, } b := NewBearerTokenPolicy( mockCredential{}, - AuthenticationOptions{ + &armpolicy.BearerTokenOptions{ Scopes: []string{scope}, AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"}, }, ) - pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) + pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("Unexpected error: %v", err) diff --git a/sdk/azcore/arm/runtime/policy_register_rp.go b/sdk/azcore/arm/runtime/policy_register_rp.go index 4aa4541a7624..f1a2a4233052 100644 --- a/sdk/azcore/arm/runtime/policy_register_rp.go +++ b/sdk/azcore/arm/runtime/policy_register_rp.go @@ -17,9 +17,10 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) @@ -30,30 +31,8 @@ const ( LogRPRegistration log.Event = "RPRegistration" ) -// RegistrationOptions configures the registration policy's behavior. -// All zero-value fields will be initialized with their default values. -type RegistrationOptions struct { - policy.ClientOptions - - // MaxAttempts is the total number of times to attempt automatic registration - // in the event that an attempt fails. - // The default value is 3. - // Set to a value less than zero to disable the policy. - MaxAttempts int - - // PollingDelay is the amount of time to sleep between polling intervals. - // The default value is 15 seconds. - // A value less than zero means no delay between polling intervals (not recommended). - PollingDelay time.Duration - - // PollingDuration is the amount of time to wait before abandoning polling. - // The default valule is 5 minutes. - // NOTE: Setting this to a small value might cause the policy to prematurely fail. - PollingDuration time.Duration -} - // init sets any default values -func (r *RegistrationOptions) init() { +func setDefaults(r *armpolicy.RegistrationOptions) { if r.MaxAttempts == 0 { r.MaxAttempts = 3 } else if r.MaxAttempts < 0 { @@ -73,28 +52,28 @@ func (r *RegistrationOptions) init() { // credentials and options. The policy controls if an unregistered resource provider should // automatically be registered. See https://aka.ms/rps-not-found for more information. // Pass nil to accept the default options; this is the same as passing a zero-value options. -func NewRPRegistrationPolicy(endpoint string, cred azcore.TokenCredential, o *RegistrationOptions) policy.Policy { +func NewRPRegistrationPolicy(endpoint string, cred azcore.TokenCredential, o *armpolicy.RegistrationOptions) azpolicy.Policy { if o == nil { - o = &RegistrationOptions{} + o = &armpolicy.RegistrationOptions{} } - authPolicy := runtime.NewBearerTokenPolicy(cred, runtime.AuthenticationOptions{TokenRequest: policy.TokenRequestOptions{Scopes: []string{shared.EndpointToScope(endpoint)}}}) + authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{Scopes: []string{shared.EndpointToScope(endpoint)}}) p := &rpRegistrationPolicy{ endpoint: endpoint, pipeline: runtime.NewPipeline(shared.Module, shared.Version, nil, []pipeline.Policy{authPolicy}, &o.ClientOptions), options: *o, } // init the copy - p.options.init() + setDefaults(&p.options) return p } type rpRegistrationPolicy struct { endpoint string pipeline pipeline.Pipeline - options RegistrationOptions + options armpolicy.RegistrationOptions } -func (r *rpRegistrationPolicy) Do(req *policy.Request) (*http.Response, error) { +func (r *rpRegistrationPolicy) Do(req *azpolicy.Request) (*http.Response, error) { if r.options.MaxAttempts == 0 { // policy is disabled return req.Next() @@ -250,7 +229,7 @@ func (client *providersOperations) Get(ctx context.Context, resourceProviderName } // getCreateRequest creates the Get request. -func (client *providersOperations) getCreateRequest(ctx context.Context, resourceProviderNamespace string) (*policy.Request, error) { +func (client *providersOperations) getCreateRequest(ctx context.Context, resourceProviderNamespace string) (*azpolicy.Request, error) { urlPath := "/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}" urlPath = strings.ReplaceAll(urlPath, "{resourceProviderNamespace}", url.PathEscape(resourceProviderNamespace)) urlPath = strings.ReplaceAll(urlPath, "{subscriptionId}", url.PathEscape(client.subID)) @@ -307,7 +286,7 @@ func (client *providersOperations) Register(ctx context.Context, resourceProvide } // registerCreateRequest creates the Register request. -func (client *providersOperations) registerCreateRequest(ctx context.Context, resourceProviderNamespace string) (*policy.Request, error) { +func (client *providersOperations) registerCreateRequest(ctx context.Context, resourceProviderNamespace string) (*azpolicy.Request, error) { urlPath := "/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}/register" urlPath = strings.ReplaceAll(urlPath, "{resourceProviderNamespace}", url.PathEscape(resourceProviderNamespace)) urlPath = strings.ReplaceAll(urlPath, "{subscriptionId}", url.PathEscape(client.subID)) diff --git a/sdk/azcore/arm/runtime/policy_register_rp_test.go b/sdk/azcore/arm/runtime/policy_register_rp_test.go index d3818536777d..780762f49b32 100644 --- a/sdk/azcore/arm/runtime/policy_register_rp_test.go +++ b/sdk/azcore/arm/runtime/policy_register_rp_test.go @@ -16,9 +16,10 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -57,11 +58,11 @@ const requestEndpoint = "/subscriptions/00000000-0000-0000-0000-000000000000/res func newTestRPRegistrationPipeline(srv *mock.Server) pipeline.Pipeline { opts := azcore.ClientOptions{Transport: srv} rp := NewRPRegistrationPolicy(srv.URL(), mockTokenCred{}, testRPRegistrationOptions(srv)) - return runtime.NewPipeline("test", "v0.1.0", []policy.Policy{rp}, nil, &opts) + return runtime.NewPipeline("test", "v0.1.0", []azpolicy.Policy{rp}, nil, &opts) } -func testRPRegistrationOptions(t policy.Transporter) *RegistrationOptions { - def := RegistrationOptions{} +func testRPRegistrationOptions(t azpolicy.Transporter) *armpolicy.RegistrationOptions { + def := armpolicy.RegistrationOptions{} def.Transport = t def.PollingDelay = 100 * time.Millisecond def.PollingDuration = 1 * time.Second @@ -70,13 +71,13 @@ func testRPRegistrationOptions(t policy.Transporter) *RegistrationOptions { type mockTokenCred struct{} -func (mockTokenCred) NewAuthenticationPolicy(runtime.AuthenticationOptions) policy.Policy { - return pipeline.PolicyFunc(func(req *policy.Request) (*http.Response, error) { +func (mockTokenCred) NewAuthenticationPolicy() azpolicy.Policy { + return pipeline.PolicyFunc(func(req *azpolicy.Request) (*http.Response, error) { return req.Next() }) } -func (mockTokenCred) GetToken(context.Context, policy.TokenRequestOptions) (*azcore.AccessToken, error) { +func (mockTokenCred) GetToken(context.Context, azpolicy.TokenRequestOptions) (*azcore.AccessToken, error) { return &azcore.AccessToken{ Token: "abc123", ExpiresOn: time.Now().Add(1 * time.Hour), @@ -294,7 +295,7 @@ func TestRPRegistrationPolicyCanCancel(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) // polling responses to Register() and Get(), in progress but slow so we have time to cancel srv.RepeatResponse(10, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp)), mock.WithSlowResponse(300*time.Millisecond)) - opts := RegistrationOptions{} + opts := armpolicy.RegistrationOptions{} opts.Transport = srv pl := newTestRPRegistrationPipeline(srv) // log only RP registration @@ -317,7 +318,7 @@ func TestRPRegistrationPolicyCanCancel(t *testing.T) { go func() { defer wg.Done() // create request and start pipeline - var req *policy.Request + var req *azpolicy.Request req, err = runtime.NewRequest(ctx, http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) if err != nil { return diff --git a/sdk/azcore/policy/policy.go b/sdk/azcore/policy/policy.go index f6968794defd..d739109c323b 100644 --- a/sdk/azcore/policy/policy.go +++ b/sdk/azcore/policy/policy.go @@ -106,6 +106,11 @@ type TokenRequestOptions struct { TenantID string } +// BearerTokenOptions configures the bearer token policy's behavior. +type BearerTokenOptions struct { + // placeholder for future options +} + // WithHTTPHeader adds the specified http.Header to the parent context. // Use this to specify custom HTTP headers at the API-call level. // Any overlapping headers will have their values replaced with the values specified here. diff --git a/sdk/azcore/runtime/policy_bearer_token.go b/sdk/azcore/runtime/policy_bearer_token.go index 0dc80e914659..d5ed61e14864 100644 --- a/sdk/azcore/runtime/policy_bearer_token.go +++ b/sdk/azcore/runtime/policy_bearer_token.go @@ -17,8 +17,8 @@ type BearerTokenPolicy struct { // mainResource is the resource to be retreived using the tenant specified in the credential mainResource *shared.ExpiringResource // the following fields are read-only - cred azcore.TokenCredential - options policy.TokenRequestOptions + cred azcore.TokenCredential + scopes []string } type acquiringResourceState struct { @@ -30,7 +30,7 @@ type acquiringResourceState struct { // thread/goroutine at a time ever calls this function func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { s := state.(acquiringResourceState) - tk, err := s.p.cred.GetToken(s.req.Raw().Context(), s.p.options) + tk, err := s.p.cred.GetToken(s.req.Raw().Context(), policy.TokenRequestOptions{Scopes: s.p.scopes}) if err != nil { return nil, time.Time{}, err } @@ -39,11 +39,12 @@ func acquire(state interface{}) (newResource interface{}, newExpiration time.Tim // NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens. // cred: an azcore.TokenCredential implementation such as a credential object from azidentity +// scopes: the list of permission scopes required for the token. // opts: optional settings. Pass nil to accept default values; this is the same as passing a zero-value options. -func NewBearerTokenPolicy(cred azcore.TokenCredential, opts AuthenticationOptions) *BearerTokenPolicy { +func NewBearerTokenPolicy(cred azcore.TokenCredential, scopes []string, opts *policy.BearerTokenOptions) *BearerTokenPolicy { return &BearerTokenPolicy{ cred: cred, - options: opts.TokenRequest, + scopes: scopes, mainResource: shared.NewExpiringResource(acquire), } } diff --git a/sdk/azcore/runtime/policy_bearer_token_test.go b/sdk/azcore/runtime/policy_bearer_token_test.go index 39d13682ebfc..02f9dd3a74e7 100644 --- a/sdk/azcore/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/runtime/policy_bearer_token_test.go @@ -36,7 +36,7 @@ func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenReque return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil } -func (mc mockCredential) NewAuthenticationPolicy(options AuthenticationOptions) policy.Policy { +func (mc mockCredential) NewAuthenticationPolicy() policy.Policy { return mc } @@ -49,10 +49,7 @@ func defaultTestPipeline(srv policy.Transporter, scope string) Pipeline { MaxRetryDelay: 500 * time.Millisecond, RetryDelay: time.Millisecond, } - b := NewBearerTokenPolicy( - mockCredential{}, - AuthenticationOptions{TokenRequest: policy.TokenRequestOptions{Scopes: []string{scope}}}, - ) + b := NewBearerTokenPolicy(mockCredential{}, []string{scope}, nil) return NewPipeline( "testmodule", "v0.1.0", @@ -90,7 +87,7 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { return nil, expectedErr } - b := NewBearerTokenPolicy(failCredential, AuthenticationOptions{}) + b := NewBearerTokenPolicy(failCredential, nil, nil) pipeline := newTestPipeline(&policy.ClientOptions{ Transport: srv, Retry: policy.RetryOptions{ @@ -142,7 +139,7 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { RetryDelay: 50 * time.Millisecond, MaxRetries: 3, } - b := NewBearerTokenPolicy(mockCredential{}, AuthenticationOptions{}) + b := NewBearerTokenPolicy(mockCredential{}, nil, nil) pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { diff --git a/sdk/azcore/runtime/transport_default_http_client.go b/sdk/azcore/runtime/transport_default_http_client.go index 3fe1fa2435e5..f7f3ca9c14ed 100644 --- a/sdk/azcore/runtime/transport_default_http_client.go +++ b/sdk/azcore/runtime/transport_default_http_client.go @@ -11,8 +11,6 @@ import ( "net" "net/http" "time" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) var defaultHTTPClient *http.Client @@ -37,11 +35,3 @@ func init() { Transport: defaultTransport, } } - -// AuthenticationOptions contains various options used to create a credential policy. -type AuthenticationOptions struct { - // TokenRequest is a TokenRequestOptions that includes a scopes field which contains - // the list of OAuth2 authentication scopes used when requesting a token. - // This field is ignored for other forms of authentication (e.g. shared key). - TokenRequest policy.TokenRequestOptions -}