Skip to content

Commit

Permalink
Return AccessToken by value (Azure#17836)
Browse files Browse the repository at this point in the history
* Return AccessToken by value

Returning by pointer is cruft from an early design and not necessary.

* remove NewAuthenticationPolicy methods
  • Loading branch information
jhendrixMSFT authored May 6, 2022
1 parent f871fa8 commit a8723e0
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 32 deletions.
1 change: 1 addition & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* Renamed `runtime.PageProcessor` to `runtime.PagingHandler`
* The `arm/runtime.ProviderRepsonse` and `arm/runtime.Provider` types are no longer exported.
* Renamed `NewRequestIdPolicy()` to `NewRequestIDPolicy()`
* `TokenCredential.GetToken` now returns `AccessToken` by value.

### Bugs Fixed
* When per-try timeouts are enabled, only cancel the context after the body has been read and closed.
Expand Down
4 changes: 2 additions & 2 deletions sdk/azcore/arm/runtime/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,15 @@ func TestPipelineAudience(t *testing.T) {
t.Fatal("unexpected audience " + audience)
}
getTokenCalled := false
cred := mockCredential{getTokenImpl: func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) {
cred := mockCredential{getTokenImpl: func(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
getTokenCalled = true
if n := len(options.Scopes); n != 1 {
t.Fatalf("expected 1 scope, got %d", n)
}
if options.Scopes[0] != audience+"/.default" {
t.Fatalf(`unexpected scope "%s"`, options.Scopes[0])
}
return &azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil
return azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil
}}
req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions sdk/azcore/arm/runtime/policy_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ type acquiringResourceState struct {

// acquire acquires or updates the resource; only one
// thread/goroutine at a time ever calls this function
func acquire(state acquiringResourceState) (newResource *azcore.AccessToken, newExpiration time.Time, err error) {
func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) {
tk, err := state.p.cred.GetToken(state.ctx, azpolicy.TokenRequestOptions{Scopes: state.p.options.Scopes})
if err != nil {
return nil, time.Time{}, err
return azcore.AccessToken{}, time.Time{}, err
}
return tk, tk.ExpiresOn, nil
}

// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential.
type BearerTokenPolicy struct {
// mainResource is the resource to be retreived using the tenant specified in the credential
mainResource *shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState]
mainResource *shared.ExpiringResource[azcore.AccessToken, acquiringResourceState]
// auxResources are additional resources that are required for cross-tenant applications
auxResources map[string]*shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState]
auxResources map[string]*shared.ExpiringResource[azcore.AccessToken, acquiringResourceState]
// the following fields are read-only
cred azcore.TokenCredential
options armpolicy.BearerTokenOptions
Expand Down
15 changes: 5 additions & 10 deletions sdk/azcore/arm/runtime/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
Expand All @@ -30,18 +29,14 @@ const (
)

type mockCredential struct {
getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error)
getTokenImpl func(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error)
}

func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) {
func (mc mockCredential) GetToken(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) {
if mc.getTokenImpl != nil {
return mc.getTokenImpl(ctx, options)
}
return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil
}

func (mc mockCredential) NewAuthenticationPolicy() azpolicy.Policy {
return mc
return azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil
}

func (mc mockCredential) Do(req *azpolicy.Request) (*http.Response, error) {
Expand Down Expand Up @@ -98,8 +93,8 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) {
defer close()
expectedErr := errors.New("oops")
failCredential := mockCredential{}
failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) {
return nil, expectedErr
failCredential.getTokenImpl = func(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) {
return azcore.AccessToken{}, expectedErr
}
b := NewBearerTokenPolicy(failCredential, nil)
pipeline := newTestPipeline(&azpolicy.ClientOptions{
Expand Down
5 changes: 2 additions & 3 deletions sdk/azcore/arm/runtime/policy_register_rp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/log"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
Expand Down Expand Up @@ -395,15 +394,15 @@ func TestRPRegistrationPolicyAudience(t *testing.T) {
},
}
getTokenCalled := false
cred := mockCredential{getTokenImpl: func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) {
cred := mockCredential{getTokenImpl: func(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) {
getTokenCalled = true
if n := len(options.Scopes); n != 1 {
t.Fatalf("expected 1 scope, got %d", n)
}
if options.Scopes[0] != audience+"/.default" {
t.Fatalf(`unexpected scope "%s"`, options.Scopes[0])
}
return &azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil
return azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil
}}
opts := azpolicy.ClientOptions{Cloud: conf, Transport: srv}
rp, err := NewRPRegistrationPolicy(cred, &armpolicy.RegistrationOptions{ClientOptions: opts})
Expand Down
2 changes: 1 addition & 1 deletion sdk/azcore/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type AccessToken struct {
// TokenCredential represents a credential capable of providing an OAuth token.
type TokenCredential interface {
// GetToken requests an access token for the specified set of scopes.
GetToken(ctx context.Context, options policy.TokenRequestOptions) (*AccessToken, error)
GetToken(ctx context.Context, options policy.TokenRequestOptions) (AccessToken, error)
}

// holds sentinel values used to send nulls
Expand Down
6 changes: 3 additions & 3 deletions sdk/azcore/runtime/policy_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential.
type BearerTokenPolicy struct {
// mainResource is the resource to be retreived using the tenant specified in the credential
mainResource *shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState]
mainResource *shared.ExpiringResource[azcore.AccessToken, acquiringResourceState]
// the following fields are read-only
cred azcore.TokenCredential
scopes []string
Expand All @@ -28,10 +28,10 @@ type acquiringResourceState struct {

// acquire acquires or updates the resource; only one
// thread/goroutine at a time ever calls this function
func acquire(state acquiringResourceState) (newResource *azcore.AccessToken, newExpiration time.Time, err error) {
func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) {
tk, err := state.p.cred.GetToken(state.req.Raw().Context(), policy.TokenRequestOptions{Scopes: state.p.scopes})
if err != nil {
return nil, time.Time{}, err
return azcore.AccessToken{}, time.Time{}, err
}
return tk, tk.ExpiresOn, nil
}
Expand Down
14 changes: 5 additions & 9 deletions sdk/azcore/runtime/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,14 @@ const (
)

type mockCredential struct {
getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error)
getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error)
}

func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) {
func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
if mc.getTokenImpl != nil {
return mc.getTokenImpl(ctx, options)
}
return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil
}

func (mc mockCredential) NewAuthenticationPolicy() policy.Policy {
return mc
return azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil
}

func (mc mockCredential) Do(req *policy.Request) (*http.Response, error) {
Expand Down Expand Up @@ -82,8 +78,8 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) {
defer close()
expectedErr := errors.New("oops")
failCredential := mockCredential{}
failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) {
return nil, expectedErr
failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
return azcore.AccessToken{}, expectedErr
}
b := NewBearerTokenPolicy(failCredential, nil, nil)
pipeline := newTestPipeline(&policy.ClientOptions{
Expand Down

0 comments on commit a8723e0

Please sign in to comment.