diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 65b8310d8aa1..6846084b5125 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -17,6 +17,7 @@ * Renamed `runtime.PageProcessor` to `runtime.PagingHandler` * The `arm/runtime.ProviderRepsonse` and `arm/runtime.Provider` types are no longer exported. * Renamed `NewRequestIdPolicy()` to `NewRequestIDPolicy()` +* `TokenCredential.GetToken` now returns `AccessToken` by value. ### Bugs Fixed * When per-try timeouts are enabled, only cancel the context after the body has been read and closed. diff --git a/sdk/azcore/arm/runtime/pipeline_test.go b/sdk/azcore/arm/runtime/pipeline_test.go index 056a1c3a2d61..c3758f5ce1ec 100644 --- a/sdk/azcore/arm/runtime/pipeline_test.go +++ b/sdk/azcore/arm/runtime/pipeline_test.go @@ -179,7 +179,7 @@ func TestPipelineAudience(t *testing.T) { t.Fatal("unexpected audience " + audience) } getTokenCalled := false - cred := mockCredential{getTokenImpl: func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { + cred := mockCredential{getTokenImpl: func(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { getTokenCalled = true if n := len(options.Scopes); n != 1 { t.Fatalf("expected 1 scope, got %d", n) @@ -187,7 +187,7 @@ func TestPipelineAudience(t *testing.T) { if options.Scopes[0] != audience+"/.default" { t.Fatalf(`unexpected scope "%s"`, options.Scopes[0]) } - return &azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil + return azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil }} req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go index 57778501548e..f77d2990056e 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -24,10 +24,10 @@ type acquiringResourceState struct { // acquire acquires or updates the resource; only one // thread/goroutine at a time ever calls this function -func acquire(state acquiringResourceState) (newResource *azcore.AccessToken, newExpiration time.Time, err error) { +func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) { tk, err := state.p.cred.GetToken(state.ctx, azpolicy.TokenRequestOptions{Scopes: state.p.options.Scopes}) if err != nil { - return nil, time.Time{}, err + return azcore.AccessToken{}, time.Time{}, err } return tk, tk.ExpiresOn, nil } @@ -35,9 +35,9 @@ func acquire(state acquiringResourceState) (newResource *azcore.AccessToken, new // BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. type BearerTokenPolicy struct { // mainResource is the resource to be retreived using the tenant specified in the credential - mainResource *shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState] + mainResource *shared.ExpiringResource[azcore.AccessToken, acquiringResourceState] // auxResources are additional resources that are required for cross-tenant applications - auxResources map[string]*shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState] + auxResources map[string]*shared.ExpiringResource[azcore.AccessToken, acquiringResourceState] // the following fields are read-only cred azcore.TokenCredential options armpolicy.BearerTokenOptions diff --git a/sdk/azcore/arm/runtime/policy_bearer_token_test.go b/sdk/azcore/arm/runtime/policy_bearer_token_test.go index 559ba150a61b..8e9657a41e53 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token_test.go @@ -16,7 +16,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" @@ -30,18 +29,14 @@ const ( ) type mockCredential struct { - getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) + getTokenImpl func(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) } -func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { +func (mc mockCredential) GetToken(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) { if mc.getTokenImpl != nil { return mc.getTokenImpl(ctx, options) } - return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil -} - -func (mc mockCredential) NewAuthenticationPolicy() azpolicy.Policy { - return mc + return azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil } func (mc mockCredential) Do(req *azpolicy.Request) (*http.Response, error) { @@ -98,8 +93,8 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { defer close() expectedErr := errors.New("oops") failCredential := mockCredential{} - failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { - return nil, expectedErr + failCredential.getTokenImpl = func(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) { + return azcore.AccessToken{}, expectedErr } b := NewBearerTokenPolicy(failCredential, nil) pipeline := newTestPipeline(&azpolicy.ClientOptions{ diff --git a/sdk/azcore/arm/runtime/policy_register_rp_test.go b/sdk/azcore/arm/runtime/policy_register_rp_test.go index 43b47802d14c..b2f082820c27 100644 --- a/sdk/azcore/arm/runtime/policy_register_rp_test.go +++ b/sdk/azcore/arm/runtime/policy_register_rp_test.go @@ -19,7 +19,6 @@ import ( armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" @@ -395,7 +394,7 @@ func TestRPRegistrationPolicyAudience(t *testing.T) { }, } getTokenCalled := false - cred := mockCredential{getTokenImpl: func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { + cred := mockCredential{getTokenImpl: func(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) { getTokenCalled = true if n := len(options.Scopes); n != 1 { t.Fatalf("expected 1 scope, got %d", n) @@ -403,7 +402,7 @@ func TestRPRegistrationPolicyAudience(t *testing.T) { if options.Scopes[0] != audience+"/.default" { t.Fatalf(`unexpected scope "%s"`, options.Scopes[0]) } - return &azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil + return azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil }} opts := azpolicy.ClientOptions{Cloud: conf, Transport: srv} rp, err := NewRPRegistrationPolicy(cred, &armpolicy.RegistrationOptions{ClientOptions: opts}) diff --git a/sdk/azcore/core.go b/sdk/azcore/core.go index b188b7f0e954..f9fb23422dfd 100644 --- a/sdk/azcore/core.go +++ b/sdk/azcore/core.go @@ -24,7 +24,7 @@ type AccessToken struct { // TokenCredential represents a credential capable of providing an OAuth token. type TokenCredential interface { // GetToken requests an access token for the specified set of scopes. - GetToken(ctx context.Context, options policy.TokenRequestOptions) (*AccessToken, error) + GetToken(ctx context.Context, options policy.TokenRequestOptions) (AccessToken, error) } // holds sentinel values used to send nulls diff --git a/sdk/azcore/runtime/policy_bearer_token.go b/sdk/azcore/runtime/policy_bearer_token.go index 3cfbff363029..cc9c8244edc1 100644 --- a/sdk/azcore/runtime/policy_bearer_token.go +++ b/sdk/azcore/runtime/policy_bearer_token.go @@ -15,7 +15,7 @@ import ( // BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. type BearerTokenPolicy struct { // mainResource is the resource to be retreived using the tenant specified in the credential - mainResource *shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState] + mainResource *shared.ExpiringResource[azcore.AccessToken, acquiringResourceState] // the following fields are read-only cred azcore.TokenCredential scopes []string @@ -28,10 +28,10 @@ type acquiringResourceState struct { // acquire acquires or updates the resource; only one // thread/goroutine at a time ever calls this function -func acquire(state acquiringResourceState) (newResource *azcore.AccessToken, newExpiration time.Time, err error) { +func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) { tk, err := state.p.cred.GetToken(state.req.Raw().Context(), policy.TokenRequestOptions{Scopes: state.p.scopes}) if err != nil { - return nil, time.Time{}, err + return azcore.AccessToken{}, time.Time{}, err } return tk, tk.ExpiresOn, nil } diff --git a/sdk/azcore/runtime/policy_bearer_token_test.go b/sdk/azcore/runtime/policy_bearer_token_test.go index 78000a6f4d77..4d59f15e4729 100644 --- a/sdk/azcore/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/runtime/policy_bearer_token_test.go @@ -25,18 +25,14 @@ const ( ) type mockCredential struct { - getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) + getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) } -func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { +func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { if mc.getTokenImpl != nil { return mc.getTokenImpl(ctx, options) } - return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil -} - -func (mc mockCredential) NewAuthenticationPolicy() policy.Policy { - return mc + return azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil } func (mc mockCredential) Do(req *policy.Request) (*http.Response, error) { @@ -82,8 +78,8 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { defer close() expectedErr := errors.New("oops") failCredential := mockCredential{} - failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { - return nil, expectedErr + failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + return azcore.AccessToken{}, expectedErr } b := NewBearerTokenPolicy(failCredential, nil, nil) pipeline := newTestPipeline(&policy.ClientOptions{