diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 56fd3d94be99..edf6588870e8 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -1,13 +1,36 @@ # Release History +## v0.19.0 + +### Breaking Changes +* Split content out of `azcore` into various packages. The intent is to separate content based on its usage (common, uncommon, SDK authors). + * `azcore` has all core functionality. + * `log` contains facilities for configuring in-box logging. + * `policy` is used for configuring pipeline options and creating custom pipeline policies. + * `runtime` contains various helpers used by SDK authors and generated content. + * `streaming` has helpers for streaming IO operations. +* `NewTelemetryPolicy()` now requires module and version parameters and the `Value` option has been removed. + * As a result, the `Request.Telemetry()` method has been removed. +* The telemetry policy now includes the SDK prefix `azsdk-go-` so callers no longer need to provide it. +* The `*http.Request` in `runtime.Request` is no longer anonymously embedded. Use the `Raw()` method to access it. +* The `UserAgent` and `Version` constants have been made internal, `Module` and `Version` respectively. + +### Bug Fixes +* Fixed an issue in the retry policy where the request body could be overwritten after a rewind. + +### Other Changes +* Moved modules `armcore` and `to` content into `arm` and `to` packages respectively. + * The `Pipeline()` method on `armcore.Connection` has been replaced by `NewPipeline()` in `arm.Connection`. It takes module and version parameters used by the telemetry policy. +* Poller logic has been consolidated across ARM and core implementations. + * This required some changes to the internal interfaces for core pollers. +* The core poller types have been improved, including more logging and test coverage. + ## v0.18.1 ### Features Added * Adds an `ETag` type for comparing etags and handling etags on requests * Simplifies the `requestBodyProgess` and `responseBodyProgress` into a single `progress` object -### Breaking Changes - ### Bugs Fixed * `JoinPaths` will preserve query parameters encoded in the `root` url. diff --git a/sdk/azcore/arm/connection.go b/sdk/azcore/arm/connection.go new file mode 100644 index 000000000000..4a7c4f530ed3 --- /dev/null +++ b/sdk/azcore/arm/connection.go @@ -0,0 +1,120 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package arm + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) + +const ( + // AzureChina is the Azure Resource Manager China cloud endpoint. + AzureChina = "https://management.chinacloudapi.cn/" + // AzureGermany is the Azure Resource Manager Germany cloud endpoint. + AzureGermany = "https://management.microsoftazure.de/" + // AzureGovernment is the Azure Resource Manager US government cloud endpoint. + AzureGovernment = "https://management.usgovcloudapi.net/" + // AzurePublicCloud is the Azure Resource Manager public cloud endpoint. + AzurePublicCloud = "https://management.azure.com/" +) + +// ConnectionOptions contains configuration settings for the connection's pipeline. +// All zero-value fields will be initialized with their default values. +type ConnectionOptions struct { + // AuxiliaryTenants contains a list of additional tenants to be used to authenticate + // across multiple tenants. + AuxiliaryTenants []string + + // HTTPClient sets the transport for making HTTP requests. + HTTPClient policy.Transporter + + // Retry configures the built-in retry policy behavior. + Retry policy.RetryOptions + + // Telemetry configures the built-in telemetry policy behavior. + Telemetry policy.TelemetryOptions + + // Logging configures the built-in logging policy behavior. + Logging policy.LogOptions + + // DisableRPRegistration disables the auto-RP registration policy. + // The default value is false. + DisableRPRegistration bool + + // PerCallPolicies contains custom policies to inject into the pipeline. + // Each policy is executed once per request. + PerCallPolicies []policy.Policy + + // PerRetryPolicies contains custom policies to inject into the pipeline. + // Each policy is executed once per request, and for each retry request. + PerRetryPolicies []policy.Policy +} + +// Connection is a connection to an Azure Resource Manager endpoint. +// It contains the base ARM endpoint and a pipeline for making requests. +type Connection struct { + ep string + cred azcore.TokenCredential + opt ConnectionOptions +} + +// NewDefaultConnection creates an instance of the Connection type using the AzurePublicCloud. +// Pass nil to accept the default options; this is the same as passing a zero-value options. +func NewDefaultConnection(cred azcore.TokenCredential, options *ConnectionOptions) *Connection { + return NewConnection(AzurePublicCloud, cred, options) +} + +// NewConnection creates an instance of the Connection type with the specified endpoint. +// Use this when connecting to clouds other than the Azure public cloud (stack/sovereign clouds). +// Pass nil to accept the default options; this is the same as passing a zero-value options. +func NewConnection(endpoint string, cred azcore.TokenCredential, options *ConnectionOptions) *Connection { + if options == nil { + options = &ConnectionOptions{} + } + return &Connection{ep: endpoint, cred: cred, opt: *options} +} + +// Endpoint returns the connection's ARM endpoint. +func (con *Connection) Endpoint() string { + return con.ep +} + +// NewPipeline creates a pipeline from the connection's options. +// The telemetry policy, when enabled, will use the specified module and version info. +func (con *Connection) NewPipeline(module, version string) pipeline.Pipeline { + policies := []policy.Policy{} + if !con.opt.Telemetry.Disabled { + policies = append(policies, azruntime.NewTelemetryPolicy(module, version, &con.opt.Telemetry)) + } + if !con.opt.DisableRPRegistration { + regRPOpts := armruntime.RegistrationOptions{ + HTTPClient: con.opt.HTTPClient, + Logging: con.opt.Logging, + Retry: con.opt.Retry, + Telemetry: con.opt.Telemetry, + } + policies = append(policies, armruntime.NewRPRegistrationPolicy(con.ep, con.cred, ®RPOpts)) + } + policies = append(policies, con.opt.PerCallPolicies...) + policies = append(policies, azruntime.NewRetryPolicy(&con.opt.Retry)) + policies = append(policies, con.opt.PerRetryPolicies...) + policies = append(policies, + con.cred.NewAuthenticationPolicy( + azruntime.AuthenticationOptions{ + TokenRequest: policy.TokenRequestOptions{ + Scopes: []string{shared.EndpointToScope(con.ep)}, + }, + AuxiliaryTenants: con.opt.AuxiliaryTenants, + }, + ), + azruntime.NewLogPolicy(&con.opt.Logging)) + return azruntime.NewPipeline(con.opt.HTTPClient, policies...) +} diff --git a/sdk/azcore/arm/connection_test.go b/sdk/azcore/arm/connection_test.go new file mode 100644 index 000000000000..be0bf26453e6 --- /dev/null +++ b/sdk/azcore/arm/connection_test.go @@ -0,0 +1,202 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package arm + +import ( + "context" + "net/http" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +type mockTokenCred struct{} + +func (mockTokenCred) NewAuthenticationPolicy(azruntime.AuthenticationOptions) policy.Policy { + return pipeline.PolicyFunc(func(req *policy.Request) (*http.Response, error) { + return req.Next() + }) +} + +func (mockTokenCred) GetToken(context.Context, policy.TokenRequestOptions) (*azcore.AccessToken, error) { + return &azcore.AccessToken{ + Token: "abc123", + ExpiresOn: time.Now().Add(1 * time.Hour), + }, nil +} + +const rpUnregisteredResp = `{ + "error":{ + "code":"MissingSubscriptionRegistration", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions.", + "details":[{ + "code":"MissingSubscriptionRegistration", + "target":"Microsoft.Storage", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions." + } + ] + } +}` + +func TestNewDefaultConnection(t *testing.T) { + opt := ConnectionOptions{} + con := NewDefaultConnection(mockTokenCred{}, &opt) + if ep := con.Endpoint(); ep != AzurePublicCloud { + t.Fatalf("unexpected endpoint %s", ep) + } +} + +func TestNewConnection(t *testing.T) { + const customEndpoint = "https://contoso.com/fake/endpoint" + con := NewConnection(customEndpoint, mockTokenCred{}, nil) + if ep := con.Endpoint(); ep != customEndpoint { + t.Fatalf("unexpected endpoint %s", ep) + } +} + +func TestNewConnectionWithOptions(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse() + opt := ConnectionOptions{} + opt.HTTPClient = srv + con := NewConnection(srv.URL(), mockTokenCred{}, &opt) + if ep := con.Endpoint(); ep != srv.URL() { + t.Fatalf("unexpected endpoint %s", ep) + } + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + if ua := resp.Request.Header.Get("User-Agent"); !strings.HasPrefix(ua, "azsdk-go-armtest/v1.2.3") { + t.Fatalf("unexpected User-Agent %s", ua) + } +} + +func TestNewConnectionWithCustomTelemetry(t *testing.T) { + const myTelemetry = "something" + srv, close := mock.NewServer() + defer close() + srv.AppendResponse() + opt := ConnectionOptions{} + opt.HTTPClient = srv + opt.Telemetry.ApplicationID = myTelemetry + con := NewConnection(srv.URL(), mockTokenCred{}, &opt) + if ep := con.Endpoint(); ep != srv.URL() { + t.Fatalf("unexpected endpoint %s", ep) + } + if opt.Telemetry.ApplicationID != myTelemetry { + t.Fatalf("telemetry was modified: %s", opt.Telemetry.ApplicationID) + } + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + if ua := resp.Request.Header.Get("User-Agent"); !strings.HasPrefix(ua, myTelemetry+" "+"azsdk-go-armtest/v1.2.3") { + t.Fatalf("unexpected User-Agent %s", ua) + } +} + +func TestDisableAutoRPRegistration(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.SetResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + con := NewConnection(srv.URL(), mockTokenCred{}, &ConnectionOptions{DisableRPRegistration: true}) + if ep := con.Endpoint(); ep != srv.URL() { + t.Fatalf("unexpected endpoint %s", ep) + } + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + // log only RP registration + log.SetClassifications(armruntime.LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d:", resp.StatusCode) + } + // shouldn't be any log entries + if logEntries != 0 { + t.Fatalf("expected 0 log entries, got %d", logEntries) + } +} + +// policy that tracks the number of times it was invoked +type countingPolicy struct { + count int +} + +func (p *countingPolicy) Do(req *policy.Request) (*http.Response, error) { + p.count++ + return req.Next() +} + +func TestConnectionWithCustomPolicies(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response is a failure to trigger retry + srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + perCallPolicy := countingPolicy{} + perRetryPolicy := countingPolicy{} + con := NewConnection(srv.URL(), mockTokenCred{}, &ConnectionOptions{ + DisableRPRegistration: true, + PerCallPolicies: []policy.Policy{&perCallPolicy}, + PerRetryPolicies: []policy.Policy{&perRetryPolicy}, + }) + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := con.NewPipeline("armtest", "v1.2.3").Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } + if perCallPolicy.count != 1 { + t.Fatalf("unexpected per call policy count %d", perCallPolicy.count) + } + if perRetryPolicy.count != 2 { + t.Fatalf("unexpected per retry policy count %d", perRetryPolicy.count) + } +} diff --git a/sdk/azcore/arm/internal/pollers/async/async.go b/sdk/azcore/arm/internal/pollers/async/async.go new file mode 100644 index 000000000000..3a3cd0a3ca96 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/async/async.go @@ -0,0 +1,139 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package async + +import ( + "errors" + "fmt" + "net/http" + + armpollers "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "Azure-AsyncOperation" + +const ( + finalStateAsync = "azure-async-operation" + finalStateLoc = "location" //nolint + finalStateOrig = "original-uri" +) + +// Applicable returns true if the LRO is using Azure-AsyncOperation. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderAzureAsync) != "" +} + +// Poller is an LRO poller that uses the Azure-AsyncOperation pattern. +type Poller struct { + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL from Azure-AsyncOperation header. + AsyncURL string `json:"asyncURL"` + + // The URL from Location header. + LocURL string `json:"locURL"` + + // The URL from the initial LRO request. + OrigURL string `json:"origURL"` + + // The HTTP method from the initial LRO request. + Method string `json:"method"` + + // The value of final-state-via from swagger, can be the empty string. + FinalState string `json:"finalState"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response and final-state type. +func New(resp *http.Response, finalState string, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Azure-AsyncOperation poller.") + asyncURL := resp.Header.Get(shared.HeaderAzureAsync) + if asyncURL == "" { + return nil, errors.New("response is missing Azure-AsyncOperation header") + } + if !pollers.IsValidURL(asyncURL) { + return nil, fmt.Errorf("invalid polling URL %s", asyncURL) + } + p := &Poller{ + Type: pollers.MakeID(pollerID, Kind), + AsyncURL: asyncURL, + LocURL: resp.Header.Get(shared.HeaderLocation), + OrigURL: resp.Request.URL.String(), + Method: resp.Request.Method, + FinalState: finalState, + } + // check for provisioning state + state, err := armpollers.GetProvisioningState(resp) + if errors.Is(err, shared.ErrNoBody) || state == "" { + // NOTE: the ARM RPC spec explicitly states that for async PUT the initial response MUST + // contain a provisioning state. to maintain compat with track 1 and other implementations + // we are explicitly relaxing this requirement. + /*if resp.Request.Method == http.MethodPut { + // initial response for a PUT requires a provisioning state + return nil, err + }*/ + // for DELETE/PATCH/POST, provisioning state is optional + state = pollers.StatusInProgress + } else if err != nil { + return nil, err + } + p.CurState = state + return p, nil +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *http.Response) error { + state, err := armpollers.GetStatus(resp) + if err != nil { + return err + } else if state == "" { + return errors.New("the response did not contain a status") + } + p.CurState = state + return nil +} + +// FinalGetURL returns the URL to perform a final GET for the payload, or the empty string if not required. +func (p *Poller) FinalGetURL() string { + if p.Method == http.MethodPatch || p.Method == http.MethodPut { + // for PATCH and PUT, the final GET is on the original resource URL + return p.OrigURL + } else if p.Method == http.MethodPost { + if p.FinalState == finalStateAsync { + return "" + } else if p.FinalState == finalStateOrig { + return p.OrigURL + } else if p.LocURL != "" { + // ideally FinalState would be set to "location" but it isn't always. + // must check last due to more permissive condition. + return p.LocURL + } + } + return "" +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.AsyncURL +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/azcore/arm/internal/pollers/async/async_test.go b/sdk/azcore/arm/internal/pollers/async/async_test.go new file mode 100644 index 000000000000..a87f503a6a77 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/async/async_test.go @@ -0,0 +1,180 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package async + +import ( + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakePollingURL = "https://foo.bar.baz/status" + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, resp io.Reader) *http.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + Request: req, + } +} + +func pollingResponse(resp io.Reader) *http.Response { + return &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + } + if Applicable(resp) { + t.Fatal("missing Azure-AsyncOperation should not be applicable") + } + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + if !Applicable(resp) { + t.Fatal("having Azure-AsyncOperation should be applicable") + } +} + +func TestNew(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(strings.NewReader(`{ "status": "InProgress" }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewDeleteNoProvState(t *testing.T) { + resp := initialResponse(http.MethodDelete, http.NoBody) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewPutNoProvState(t *testing.T) { + // missing provisioning state on initial response + // NOTE: ARM RPC forbids this but we allow it for back-compat + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewFinalGetLocation(t *testing.T) { + const ( + jsonBody = `{ "properties": { "provisioningState": "Started" } }` + locURL = "https://foo.bar.baz/location" + ) + resp := initialResponse(http.MethodPost, strings.NewReader(jsonBody)) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, locURL) + poller, err := New(resp, "location", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != locURL { + t.Fatalf("unexpected final get URL %s", u) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } +} + +func TestNewFinalGetOrigin(t *testing.T) { + const ( + jsonBody = `{ "properties": { "provisioningState": "Started" } }` + locURL = "https://foo.bar.baz/location" + ) + resp := initialResponse(http.MethodPost, strings.NewReader(jsonBody)) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, locURL) + poller, err := New(resp, "original-uri", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected polling URL %s", u) + } +} + +func TestNewPutNoProvStateOnUpdate(t *testing.T) { + // missing provisioning state on initial response + // NOTE: ARM RPC forbids this but we allow it for back-compat + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderAzureAsync, fakePollingURL) + poller, err := New(resp, "", "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(strings.NewReader("{}"))); err == nil { + t.Fatal("unexpected nil error") + } +} diff --git a/sdk/azcore/arm/internal/pollers/body/body.go b/sdk/azcore/arm/internal/pollers/body/body.go new file mode 100644 index 000000000000..ea5fa6b468e0 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/body/body.go @@ -0,0 +1,111 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package body + +import ( + "errors" + "net/http" + + armpollers "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "Body" + +// Applicable returns true if the LRO is using no headers, just provisioning state. +// This is only applicable to PATCH and PUT methods and assumes no polling headers. +func Applicable(resp *http.Response) bool { + // we can't check for absense of headers due to some misbehaving services + // like redis that return a Location header but don't actually use that protocol + return resp.Request.Method == http.MethodPatch || resp.Request.Method == http.MethodPut +} + +// Poller is an LRO poller that uses the Body pattern. +type Poller struct { + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL for polling. + PollURL string `json:"pollURL"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *http.Response, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Body poller.") + p := &Poller{ + Type: pollers.MakeID(pollerID, Kind), + PollURL: resp.Request.URL.String(), + } + // default initial state to InProgress. depending on the HTTP + // status code and provisioning state, we might change the value. + curState := pollers.StatusInProgress + provState, err := armpollers.GetProvisioningState(resp) + if err != nil && !errors.Is(err, shared.ErrNoBody) { + return nil, err + } + if resp.StatusCode == http.StatusCreated && provState != "" { + // absense of provisioning state is ok for a 201, means the operation is in progress + curState = provState + } else if resp.StatusCode == http.StatusOK { + if provState != "" { + curState = provState + } else if provState == "" { + // for a 200, absense of provisioning state indicates success + curState = pollers.StatusSucceeded + } + } else if resp.StatusCode == http.StatusNoContent { + curState = pollers.StatusSucceeded + } + p.CurState = curState + return p, nil +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.PollURL +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *http.Response) error { + if resp.StatusCode == http.StatusNoContent { + p.CurState = pollers.StatusSucceeded + return nil + } + state, err := armpollers.GetProvisioningState(resp) + if errors.Is(err, shared.ErrNoBody) { + // a missing response body in non-204 case is an error + return err + } else if state == "" { + // a response body without provisioning state is considered terminal success + state = pollers.StatusSucceeded + } else if err != nil { + return err + } + p.CurState = state + return nil +} + +// FinalGetURL returns the empty string as no final GET is required for this poller type. +func (*Poller) FinalGetURL() string { + return "" +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/azcore/arm/internal/pollers/body/body_test.go b/sdk/azcore/arm/internal/pollers/body/body_test.go new file mode 100644 index 000000000000..aa19670dbd4e --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/body/body_test.go @@ -0,0 +1,207 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package body + +import ( + "errors" + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, resp io.Reader) *http.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + Request: req, + } +} + +func pollingResponse(status int, resp io.Reader) *http.Response { + return &http.Response{ + Body: ioutil.NopCloser(resp), + Header: http.Header{}, + StatusCode: status, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + Request: &http.Request{ + Method: http.MethodDelete, + }, + } + if Applicable(resp) { + t.Fatal("method DELETE should not be applicable") + } + resp.Request.Method = http.MethodPatch + if !Applicable(resp) { + t.Fatal("method PATCH should be applicable") + } + resp.Request.Method = http.MethodPut + if !Applicable(resp) { + t.Fatal("method PUT should be applicable") + } +} + +func TestNew(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusCreated + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{ "properties": { "provisioningState": "InProgress" } }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateNoProvStateFail(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusOK, http.NoBody)) + if err == nil { + t.Fatal("unexpected nil error") + } + if !errors.Is(err, shared.ErrNoBody) { + t.Fatalf("unexpected error type %T", err) + } +} + +func TestUpdateNoProvStateSuccess(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{}`))) + if err != nil { + t.Fatal(err) + } +} + +func TestUpdateNoProvState204(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Started" } }` + resp := initialResponse(http.MethodPut, strings.NewReader(jsonBody)) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Started" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeResourceURL { + t.Fatalf("unexpected polling URL %s", u) + } + err = poller.Update(pollingResponse(http.StatusNoContent, http.NoBody)) + if err != nil { + t.Fatal(err) + } +} + +func TestNewNoInitialProvStateOK(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.StatusCode = http.StatusOK + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("poller not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewNoInitialProvStateNC(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.StatusCode = http.StatusNoContent + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("poller not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} diff --git a/sdk/azcore/arm/internal/pollers/loc/loc.go b/sdk/azcore/arm/internal/pollers/loc/loc.go new file mode 100644 index 000000000000..a1b8a23234fa --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/loc/loc.go @@ -0,0 +1,104 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "errors" + "fmt" + "net/http" + + armpollers "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "ARM-Location" + +// Applicable returns true if the LRO is using Location. +func Applicable(resp *http.Response) bool { + return resp.StatusCode == http.StatusAccepted && resp.Header.Get(shared.HeaderLocation) != "" +} + +// Poller is an LRO poller that uses the Location pattern. +type Poller struct { + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL for polling. + PollURL string `json:"pollURL"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *http.Response, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Location poller.") + locURL := resp.Header.Get(shared.HeaderLocation) + if locURL == "" { + return nil, errors.New("response is missing Location header") + } + if !pollers.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid polling URL %s", locURL) + } + p := &Poller{ + Type: pollers.MakeID(pollerID, Kind), + PollURL: locURL, + CurState: pollers.StatusInProgress, + } + return p, nil +} + +// URL returns the polling URL. +func (p *Poller) URL() string { + return p.PollURL +} + +// Done returns true if the LRO has reached a terminal state. +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +// Update updates the Poller from the polling response. +func (p *Poller) Update(resp *http.Response) error { + // location polling can return an updated polling URL + if h := resp.Header.Get(shared.HeaderLocation); h != "" { + p.PollURL = h + } + if runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) { + // if a 200/201 returns a provisioning state, use that instead + state, err := armpollers.GetProvisioningState(resp) + if err != nil && !errors.Is(err, shared.ErrNoBody) { + return err + } + if state != "" { + p.CurState = state + } else { + // a 200/201 with no provisioning state indicates success + p.CurState = pollers.StatusSucceeded + } + } else if resp.StatusCode == http.StatusNoContent { + p.CurState = pollers.StatusSucceeded + } else if resp.StatusCode > 399 && resp.StatusCode < 500 { + p.CurState = pollers.StatusFailed + } + // a 202 falls through, means the LRO is still in progress and we don't check for provisioning state + return nil +} + +// FinalGetURL returns the empty string as no final GET is required for this poller type. +func (p *Poller) FinalGetURL() string { + return "" +} + +// Status returns the status of the LRO. +func (p *Poller) Status() string { + return p.CurState +} diff --git a/sdk/azcore/arm/internal/pollers/loc/loc_test.go b/sdk/azcore/arm/internal/pollers/loc/loc_test.go new file mode 100644 index 000000000000..06365a1d6ec6 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/loc/loc_test.go @@ -0,0 +1,133 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakePollingURL1 = "https://foo.bar.baz/status" + fakePollingURL2 = "https://foo.bar.baz/updated" +) + +func initialResponse(method string) *http.Response { + return &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + } +} + +func pollingResponse(statusCode int, body io.Reader) *http.Response { + return &http.Response{ + Body: ioutil.NopCloser(body), + Header: http.Header{}, + StatusCode: statusCode, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + } + if Applicable(resp) { + t.Fatal("missing Location should not be applicable") + } + resp.Header.Set(shared.HeaderLocation, fakePollingURL1) + if !Applicable(resp) { + t.Fatal("having Location should be applicable") + } +} + +func TestNew(t *testing.T) { + resp := initialResponse(http.MethodPut) + resp.Header.Set(shared.HeaderLocation, fakePollingURL1) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL1 { + t.Fatalf("unexpected polling URL %s", u) + } + pr := pollingResponse(http.StatusAccepted, http.NoBody) + pr.Header.Set(shared.HeaderLocation, fakePollingURL2) + if err := poller.Update(pr); err != nil { + t.Fatal(err) + } + if u := poller.URL(); u != fakePollingURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusNoContent, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(http.StatusConflict, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Failed" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateWithProvState(t *testing.T) { + resp := initialResponse(http.MethodPut) + resp.Header.Set(shared.HeaderLocation, fakePollingURL1) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("poller should not be done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatal("expected empty final GET URL") + } + if s := poller.Status(); s != "InProgress" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL1 { + t.Fatalf("unexpected polling URL %s", u) + } + pr := pollingResponse(http.StatusAccepted, http.NoBody) + pr.Header.Set(shared.HeaderLocation, fakePollingURL2) + if err := poller.Update(pr); err != nil { + t.Fatal(err) + } + if u := poller.URL(); u != fakePollingURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(pollingResponse(http.StatusOK, strings.NewReader(`{ "properties": { "provisioningState": "Updating" } }`))); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Updating" { + t.Fatalf("unexpected status %s", s) + } + if err := poller.Update(pollingResponse(http.StatusOK, http.NoBody)); err != nil { + t.Fatal(err) + } + if s := poller.Status(); s != "Succeeded" { + t.Fatalf("unexpected status %s", s) + } +} diff --git a/sdk/azcore/arm/internal/pollers/pollers.go b/sdk/azcore/arm/internal/pollers/pollers.go new file mode 100644 index 000000000000..3b9f5581c017 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/pollers.go @@ -0,0 +1,68 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +// provisioningState returns the provisioning state from the response or the empty string. +func provisioningState(jsonBody map[string]interface{}) string { + jsonProps, ok := jsonBody["properties"] + if !ok { + return "" + } + props, ok := jsonProps.(map[string]interface{}) + if !ok { + return "" + } + rawPs, ok := props["provisioningState"] + if !ok { + return "" + } + ps, ok := rawPs.(string) + if !ok { + return "" + } + return ps +} + +// status returns the status from the response or the empty string. +func status(jsonBody map[string]interface{}) string { + rawStatus, ok := jsonBody["status"] + if !ok { + return "" + } + status, ok := rawStatus.(string) + if !ok { + return "" + } + return status +} + +// GetStatus returns the LRO's status from the response body. +// Typically used for Azure-AsyncOperation flows. +// If there is no status in the response body the empty string is returned. +func GetStatus(resp *http.Response) (string, error) { + jsonBody, err := shared.GetJSON(resp) + if err != nil { + return "", err + } + return status(jsonBody), nil +} + +// GetProvisioningState returns the LRO's state from the response body. +// If there is no state in the response body the empty string is returned. +func GetProvisioningState(resp *http.Response) (string, error) { + jsonBody, err := shared.GetJSON(resp) + if err != nil { + return "", err + } + return provisioningState(jsonBody), nil +} diff --git a/sdk/azcore/arm/internal/pollers/pollers_test.go b/sdk/azcore/arm/internal/pollers/pollers_test.go new file mode 100644 index 000000000000..10808c908256 --- /dev/null +++ b/sdk/azcore/arm/internal/pollers/pollers_test.go @@ -0,0 +1,91 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "errors" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +func TestGetStatusSuccess(t *testing.T) { + const jsonBody = `{ "status": "InProgress" }` + resp := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader(jsonBody)), + } + status, err := GetStatus(resp) + if err != nil { + t.Fatal(err) + } + if status != "InProgress" { + t.Fatalf("unexpected status %s", status) + } +} + +func TestGetNoBody(t *testing.T) { + resp := &http.Response{ + Body: http.NoBody, + } + status, err := GetStatus(resp) + if !errors.Is(err, shared.ErrNoBody) { + t.Fatalf("unexpected error %T", err) + } + if status != "" { + t.Fatal("expected empty status") + } + status, err = GetProvisioningState(resp) + if !errors.Is(err, shared.ErrNoBody) { + t.Fatalf("unexpected error %T", err) + } + if status != "" { + t.Fatal("expected empty status") + } +} + +func TestGetStatusError(t *testing.T) { + resp := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("{}")), + } + status, err := GetStatus(resp) + if err != nil { + t.Fatal(err) + } + if status != "" { + t.Fatalf("expected empty status, got %s", status) + } +} + +func TestGetProvisioningState(t *testing.T) { + const jsonBody = `{ "properties": { "provisioningState": "Canceled" } }` + resp := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader(jsonBody)), + } + state, err := GetProvisioningState(resp) + if err != nil { + t.Fatal(err) + } + if state != "Canceled" { + t.Fatalf("unexpected status %s", state) + } +} + +func TestGetProvisioningStateError(t *testing.T) { + resp := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("{}")), + } + state, err := GetProvisioningState(resp) + if err != nil { + t.Fatal(err) + } + if state != "" { + t.Fatalf("expected empty provisioning state, got %s", state) + } +} diff --git a/sdk/azcore/arm/runtime/policy_register_rp.go b/sdk/azcore/arm/runtime/policy_register_rp.go new file mode 100644 index 000000000000..9b5e16250277 --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_register_rp.go @@ -0,0 +1,384 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +const ( + // LogRPRegistration entries contain information specific to the automatic registration of an RP. + // Entries of this classification are written IFF the policy needs to take any action. + LogRPRegistration log.Classification = "RPRegistration" +) + +// RegistrationOptions configures the registration policy's behavior. +// All zero-value fields will be initialized with their default values. +type RegistrationOptions struct { + // MaxAttempts is the total number of times to attempt automatic registration + // in the event that an attempt fails. + // The default value is 3. + // Set to a value less than zero to disable the policy. + MaxAttempts int + + // PollingDelay is the amount of time to sleep between polling intervals. + // The default value is 15 seconds. + // A value less than zero means no delay between polling intervals (not recommended). + PollingDelay time.Duration + + // PollingDuration is the amount of time to wait before abandoning polling. + // The default valule is 5 minutes. + // NOTE: Setting this to a small value might cause the policy to prematurely fail. + PollingDuration time.Duration + + // HTTPClient sets the transport for making HTTP requests. + HTTPClient policy.Transporter + + // Retry configures the built-in retry policy behavior. + Retry policy.RetryOptions + + // Telemetry configures the built-in telemetry policy behavior. + Telemetry policy.TelemetryOptions + + // Logging configures the built-in logging policy behavior. + Logging policy.LogOptions +} + +// init sets any default values +func (r *RegistrationOptions) init() { + if r.MaxAttempts == 0 { + r.MaxAttempts = 3 + } else if r.MaxAttempts < 0 { + r.MaxAttempts = 0 + } + if r.PollingDelay == 0 { + r.PollingDelay = 15 * time.Second + } else if r.PollingDelay < 0 { + r.PollingDelay = 0 + } + if r.PollingDuration == 0 { + r.PollingDuration = 5 * time.Minute + } +} + +// NewRPRegistrationPolicy creates a policy object configured using the specified endpoint, +// credentials and options. The policy controls if an unregistered resource provider should +// automatically be registered. See https://aka.ms/rps-not-found for more information. +// Pass nil to accept the default options; this is the same as passing a zero-value options. +func NewRPRegistrationPolicy(endpoint string, cred azcore.Credential, o *RegistrationOptions) policy.Policy { + if o == nil { + o = &RegistrationOptions{} + } + p := &rpRegistrationPolicy{ + endpoint: endpoint, + pipeline: runtime.NewPipeline(o.HTTPClient, + runtime.NewTelemetryPolicy(shared.Module, shared.Version, &o.Telemetry), + runtime.NewRetryPolicy(&o.Retry), + cred.NewAuthenticationPolicy(runtime.AuthenticationOptions{TokenRequest: policy.TokenRequestOptions{Scopes: []string{shared.EndpointToScope(endpoint)}}}), + runtime.NewLogPolicy(&o.Logging)), + options: *o, + } + // init the copy + p.options.init() + return p +} + +type rpRegistrationPolicy struct { + endpoint string + pipeline pipeline.Pipeline + options RegistrationOptions +} + +func (r *rpRegistrationPolicy) Do(req *policy.Request) (*http.Response, error) { + if r.options.MaxAttempts == 0 { + // policy is disabled + return req.Next() + } + const unregisteredRPCode = "MissingSubscriptionRegistration" + const registeredState = "Registered" + var rp string + var resp *http.Response + for attempts := 0; attempts < r.options.MaxAttempts; attempts++ { + var err error + // make the original request + resp, err = req.Next() + // getting a 409 is the first indication that the RP might need to be registered, check error response + if err != nil || resp.StatusCode != http.StatusConflict { + return resp, err + } + var reqErr requestError + if err = runtime.UnmarshalAsJSON(resp, &reqErr); err != nil { + return resp, err + } + if reqErr.ServiceError == nil { + return resp, errors.New("missing error information") + } + if !strings.EqualFold(reqErr.ServiceError.Code, unregisteredRPCode) { + // not a 409 due to unregistered RP + return resp, err + } + // RP needs to be registered. start by getting the subscription ID from the original request + subID, err := getSubscription(req.Raw().URL.Path) + if err != nil { + return resp, err + } + // now get the RP from the error + rp, err = getProvider(reqErr) + if err != nil { + return resp, err + } + logRegistrationExit := func(v interface{}) { + log.Writef(LogRPRegistration, "END registration for %s: %v", rp, v) + } + log.Writef(LogRPRegistration, "BEGIN registration for %s", rp) + // create client and make the registration request + // we use the scheme and host from the original request + rpOps := &providersOperations{ + p: r.pipeline, + u: r.endpoint, + subID: subID, + } + if _, err = rpOps.Register(req.Raw().Context(), rp); err != nil { + logRegistrationExit(err) + return resp, err + } + // RP was registered, however we need to wait for the registration to complete + pollCtx, pollCancel := context.WithTimeout(req.Raw().Context(), r.options.PollingDuration) + var lastRegState string + for { + // get the current registration state + getResp, err := rpOps.Get(pollCtx, rp) + if err != nil { + pollCancel() + logRegistrationExit(err) + return resp, err + } + if getResp.Provider.RegistrationState != nil && !strings.EqualFold(*getResp.Provider.RegistrationState, lastRegState) { + // registration state has changed, or was updated for the first time + lastRegState = *getResp.Provider.RegistrationState + log.Writef(LogRPRegistration, "registration state is %s", lastRegState) + } + if strings.EqualFold(lastRegState, registeredState) { + // registration complete + pollCancel() + logRegistrationExit(lastRegState) + break + } + // wait before trying again + select { + case <-time.After(r.options.PollingDelay): + // continue polling + case <-pollCtx.Done(): + pollCancel() + logRegistrationExit(pollCtx.Err()) + return resp, pollCtx.Err() + } + } + // RP was successfully registered, retry the original request + err = req.RewindBody() + if err != nil { + return resp, err + } + } + // if we get here it means we exceeded the number of attempts + return resp, fmt.Errorf("exceeded attempts to register %s", rp) +} + +func getSubscription(path string) (string, error) { + parts := strings.Split(path, "/") + for i, v := range parts { + if v == "subscriptions" && (i+1) < len(parts) { + return parts[i+1], nil + } + } + return "", fmt.Errorf("failed to obtain subscription ID from %s", path) +} + +func getProvider(re requestError) (string, error) { + if len(re.ServiceError.Details) > 0 { + return re.ServiceError.Details[0].Target, nil + } + return "", errors.New("unexpected empty Details") +} + +// minimal error definitions to simplify detection +type requestError struct { + ServiceError *serviceError `json:"error"` +} + +type serviceError struct { + Code string `json:"code"` + Details []serviceErrorDetails `json:"details"` +} + +type serviceErrorDetails struct { + Code string `json:"code"` + Target string `json:"target"` +} + +/////////////////////////////////////////////////////////////////////////////////////////////// +// the following code was copied from module armresources, providers.go and models.go +// only the minimum amount of code was copied to get this working and some edits were made. +/////////////////////////////////////////////////////////////////////////////////////////////// + +type providersOperations struct { + p pipeline.Pipeline + u string + subID string +} + +// Get - Gets the specified resource provider. +func (client *providersOperations) Get(ctx context.Context, resourceProviderNamespace string) (*ProviderResponse, error) { + req, err := client.getCreateRequest(ctx, resourceProviderNamespace) + if err != nil { + return nil, err + } + resp, err := client.p.Do(req) + if err != nil { + return nil, err + } + result, err := client.getHandleResponse(resp) + if err != nil { + return nil, err + } + return result, nil +} + +// getCreateRequest creates the Get request. +func (client *providersOperations) getCreateRequest(ctx context.Context, resourceProviderNamespace string) (*policy.Request, error) { + urlPath := "/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}" + urlPath = strings.ReplaceAll(urlPath, "{resourceProviderNamespace}", url.PathEscape(resourceProviderNamespace)) + urlPath = strings.ReplaceAll(urlPath, "{subscriptionId}", url.PathEscape(client.subID)) + req, err := runtime.NewRequest(ctx, http.MethodGet, runtime.JoinPaths(client.u, urlPath)) + if err != nil { + return nil, err + } + query := req.Raw().URL.Query() + query.Set("api-version", "2019-05-01") + req.Raw().URL.RawQuery = query.Encode() + return req, nil +} + +// getHandleResponse handles the Get response. +func (client *providersOperations) getHandleResponse(resp *http.Response) (*ProviderResponse, error) { + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, client.getHandleError(resp) + } + result := ProviderResponse{RawResponse: resp} + err := runtime.UnmarshalAsJSON(resp, &result.Provider) + if err != nil { + return nil, err + } + return &result, err +} + +// getHandleError handles the Get error response. +func (client *providersOperations) getHandleError(resp *http.Response) error { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return shared.NewResponseError(err, resp) + } + if len(body) == 0 { + return shared.NewResponseError(errors.New(resp.Status), resp) + } + return shared.NewResponseError(errors.New(string(body)), resp) +} + +// Register - Registers a subscription with a resource provider. +func (client *providersOperations) Register(ctx context.Context, resourceProviderNamespace string) (*ProviderResponse, error) { + req, err := client.registerCreateRequest(ctx, resourceProviderNamespace) + if err != nil { + return nil, err + } + resp, err := client.p.Do(req) + if err != nil { + return nil, err + } + result, err := client.registerHandleResponse(resp) + if err != nil { + return nil, err + } + return result, nil +} + +// registerCreateRequest creates the Register request. +func (client *providersOperations) registerCreateRequest(ctx context.Context, resourceProviderNamespace string) (*policy.Request, error) { + urlPath := "/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}/register" + urlPath = strings.ReplaceAll(urlPath, "{resourceProviderNamespace}", url.PathEscape(resourceProviderNamespace)) + urlPath = strings.ReplaceAll(urlPath, "{subscriptionId}", url.PathEscape(client.subID)) + req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(client.u, urlPath)) + if err != nil { + return nil, err + } + query := req.Raw().URL.Query() + query.Set("api-version", "2019-05-01") + req.Raw().URL.RawQuery = query.Encode() + return req, nil +} + +// registerHandleResponse handles the Register response. +func (client *providersOperations) registerHandleResponse(resp *http.Response) (*ProviderResponse, error) { + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, client.registerHandleError(resp) + } + result := ProviderResponse{RawResponse: resp} + err := runtime.UnmarshalAsJSON(resp, &result.Provider) + if err != nil { + return nil, err + } + return &result, err +} + +// registerHandleError handles the Register error response. +func (client *providersOperations) registerHandleError(resp *http.Response) error { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return shared.NewResponseError(err, resp) + } + if len(body) == 0 { + return shared.NewResponseError(errors.New(resp.Status), resp) + } + return shared.NewResponseError(errors.New(string(body)), resp) +} + +// ProviderResponse is the response envelope for operations that return a Provider type. +type ProviderResponse struct { + // Resource provider information. + Provider *Provider + + // RawResponse contains the underlying HTTP response. + RawResponse *http.Response +} + +// Provider - Resource provider information. +type Provider struct { + // The provider ID. + ID *string `json:"id,omitempty"` + + // The namespace of the resource provider. + Namespace *string `json:"namespace,omitempty"` + + // The registration policy of the resource provider. + RegistrationPolicy *string `json:"registrationPolicy,omitempty"` + + // The registration state of the resource provider. + RegistrationState *string `json:"registrationState,omitempty"` +} diff --git a/sdk/azcore/arm/runtime/policy_register_rp_test.go b/sdk/azcore/arm/runtime/policy_register_rp_test.go new file mode 100644 index 000000000000..05b5318f5567 --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_register_rp_test.go @@ -0,0 +1,372 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "errors" + "net/http" + "strings" + "sync" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const rpUnregisteredResp = `{ + "error":{ + "code":"MissingSubscriptionRegistration", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions.", + "details":[{ + "code":"MissingSubscriptionRegistration", + "target":"Microsoft.Storage", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions." + } + ] + } +}` + +// some content was omitted here as it's not relevant +const rpRegisteringResp = `{ + "id": "/subscriptions/00000000-0000-0000-0000-000000000000/providers/Microsoft.Storage", + "namespace": "Microsoft.Storage", + "registrationState": "Registering", + "registrationPolicy": "RegistrationRequired" +}` + +// some content was omitted here as it's not relevant +const rpRegisteredResp = `{ + "id": "/subscriptions/00000000-0000-0000-0000-000000000000/providers/Microsoft.Storage", + "namespace": "Microsoft.Storage", + "registrationState": "Registered", + "registrationPolicy": "RegistrationRequired" +}` + +const requestEndpoint = "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/fakeResourceGroupo/providers/Microsoft.Storage/storageAccounts/fakeAccountName" + +func testRPRegistrationOptions(t policy.Transporter) *RegistrationOptions { + def := RegistrationOptions{} + def.HTTPClient = t + def.PollingDelay = 100 * time.Millisecond + def.PollingDuration = 1 * time.Second + return &def +} + +type mockTokenCred struct{} + +func (mockTokenCred) NewAuthenticationPolicy(runtime.AuthenticationOptions) policy.Policy { + return pipeline.PolicyFunc(func(req *policy.Request) (*http.Response, error) { + return req.Next() + }) +} + +func (mockTokenCred) GetToken(context.Context, policy.TokenRequestOptions) (*azcore.AccessToken, error) { + return &azcore.AccessToken{ + Token: "abc123", + ExpiresOn: time.Now().Add(1 * time.Hour), + }, nil +} + +func TestRPRegistrationPolicySuccess(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + // polling responses to Register() and Get(), in progress + srv.RepeatResponse(5, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp))) + // polling response, successful registration + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteredResp))) + // response for original request (different status code than any of the other responses) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), mockTokenCred{}, testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("unexpected status code %d:", resp.StatusCode) + } + if resp.Request.URL.Path != requestEndpoint { + t.Fatalf("unexpected path in response %s", resp.Request.URL.Path) + } + // should be four entries + // 1st is for start + // 2nd is for first response to get state + // 3rd is when state transitions to success + // 4th is for end + if logEntries != 4 { + t.Fatalf("expected 4 log entries, got %d", logEntries) + } +} + +func TestRPRegistrationPolicyNA(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // response indicates no RP registration is required, policy does nothing + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + log.SetListener(func(cls log.Classification, msg string) { + t.Fatalf("unexpected log entry %s: %s", cls, msg) + }) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestRPRegistrationPolicy409Other(t *testing.T) { + const failedResp = `{ + "error":{ + "code":"CannotDoTheThing", + "message":"Something failed in your API call.", + "details":[{ + "code":"ThisIsForTesting", + "message":"This is fake." + } + ] + } + }` + srv, close := mock.NewServer() + defer close() + // test getting a 409 but not due to registration required + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(failedResp))) + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + log.SetListener(func(cls log.Classification, msg string) { + t.Fatalf("unexpected log entry %s: %s", cls, msg) + }) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestRPRegistrationPolicyTimesOut(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + // polling responses to Register() and Get(), in progress but slow + // tests registration takes too long, times out + srv.RepeatResponse(10, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp)), mock.WithSlowResponse(400*time.Millisecond)) + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := pl.Do(req) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected DeadlineExceeded, got %v", err) + } + // should be three entries + // 1st is for start + // 2nd is for first response to get state + // 3rd is the deadline exceeded error + if logEntries != 3 { + t.Fatalf("expected 3 log entries, got %d", logEntries) + } + // we should get the response from the original request + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestRPRegistrationPolicyExceedsAttempts(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // add a cycle of unregistered->registered so that we keep retrying and hit the cap + for i := 0; i < 4; i++ { + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + // polling responses to Register() and Get(), in progress + srv.RepeatResponse(2, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp))) + // polling response, successful registration + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteredResp))) + } + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), testRPRegistrationOptions(srv))) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if !strings.HasPrefix(err.Error(), "exceeded attempts to register Microsoft.Storage") { + t.Fatalf("unexpected error message %s", err.Error()) + } + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d:", resp.StatusCode) + } + if resp.Request.URL.Path != requestEndpoint { + t.Fatalf("unexpected path in response %s", resp.Request.URL.Path) + } + // should be 4 entries for each attempt, total 12 entries + // 1st is for start + // 2nd is for first response to get state + // 3rd is when state transitions to success + // 4th is for end + if logEntries != 12 { + t.Fatalf("expected 12 log entries, got %d", logEntries) + } +} + +// test cancelling registration +func TestRPRegistrationPolicyCanCancel(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + // polling responses to Register() and Get(), in progress but slow so we have time to cancel + srv.RepeatResponse(10, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp)), mock.WithSlowResponse(300*time.Millisecond)) + opts := RegistrationOptions{} + opts.HTTPClient = srv + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), &opts)) + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + + wg := &sync.WaitGroup{} + wg.Add(1) + + ctx, cancel := context.WithCancel(context.Background()) + var resp *http.Response + var err error + go func() { + defer wg.Done() + // create request and start pipeline + var req *policy.Request + req, err = runtime.NewRequest(ctx, http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + return + } + resp, err = pl.Do(req) + }() + + // wait for a bit then cancel the operation + time.Sleep(500 * time.Millisecond) + cancel() + wg.Wait() + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected Canceled error, got %v", err) + } + // there should be 1 or 2 entries depending on the timing + if logEntries == 0 { + t.Fatal("didn't get any log entries") + } + // should have original response + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestRPRegistrationPolicyDisabled(t *testing.T) { + srv, close := mock.NewServer() + defer close() + // initial response that RP is unregistered + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + ops := testRPRegistrationOptions(srv) + ops.MaxAttempts = -1 + pl := runtime.NewPipeline(srv, NewRPRegistrationPolicy(srv.URL(), azcore.NewAnonymousCredential(), ops)) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) + if err != nil { + t.Fatal(err) + } + // log only RP registration + log.SetClassifications(LogRPRegistration) + defer func() { + // reset logging + log.SetClassifications() + }() + logEntries := 0 + log.SetListener(func(cls log.Classification, msg string) { + logEntries++ + }) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusConflict { + t.Fatalf("unexpected status code %d:", resp.StatusCode) + } + // shouldn't be any log entries + if logEntries != 0 { + t.Fatalf("expected 0 log entries, got %d", logEntries) + } +} diff --git a/sdk/azcore/arm/runtime/poller.go b/sdk/azcore/arm/runtime/poller.go new file mode 100644 index 000000000000..f4e6df175bb0 --- /dev/null +++ b/sdk/azcore/arm/runtime/poller.go @@ -0,0 +1,81 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/async" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/loc" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// NewPoller creates a Poller based on the provided initial response. +// pollerID - a unique identifier for an LRO. it's usually the client.Method string. +func NewPoller(pollerID string, finalState string, resp *http.Response, pl pipeline.Pipeline, eu func(*http.Response) error) (*pollers.Poller, error) { + // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). + // ideally the codegen should return an error if the initial response failed and not even create a poller. + if !pollers.StatusCodeValid(resp) { + return nil, errors.New("the LRO failed or was cancelled") + } + // determine the polling method + var lro pollers.Operation + var err error + if async.Applicable(resp) { + lro, err = async.New(resp, finalState, pollerID) + } else if loc.Applicable(resp) { + lro, err = loc.New(resp, pollerID) + } else if body.Applicable(resp) { + // must test body poller last as it's a subset of the other pollers. + // TODO: this is ambiguous for PATCH/PUT if it returns a 200 with no polling headers (sync completion) + lro, err = body.New(resp, pollerID) + } else if m := resp.Request.Method; resp.StatusCode == http.StatusAccepted && (m == http.MethodDelete || m == http.MethodPost) { + // if we get here it means we have a 202 with no polling headers. + // for DELETE and POST this is a hard error per ARM RPC spec. + return nil, errors.New("response is missing polling URL") + } else { + lro = &pollers.NopPoller{} + } + if err != nil { + return nil, err + } + return pollers.NewPoller(lro, resp, pl, eu), nil +} + +// NewPollerFromResumeToken creates a Poller from a resume token string. +// pollerID - a unique identifier for an LRO. it's usually the client.Method string. +func NewPollerFromResumeToken(pollerID string, token string, pl pipeline.Pipeline, eu func(*http.Response) error) (*pollers.Poller, error) { + kind, err := pollers.KindFromToken(pollerID, token) + if err != nil { + return nil, err + } + // now rehydrate the poller based on the encoded poller type + var lro pollers.Operation + switch kind { + case async.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", async.Kind) + lro = &async.Poller{} + case loc.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", loc.Kind) + lro = &loc.Poller{} + case body.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", body.Kind) + lro = &body.Poller{} + default: + return nil, fmt.Errorf("unhandled poller type %s", kind) + } + if err = json.Unmarshal([]byte(token), lro); err != nil { + return nil, err + } + return pollers.NewPoller(lro, nil, pl, eu), nil +} diff --git a/sdk/azcore/arm/runtime/poller_test.go b/sdk/azcore/arm/runtime/poller_test.go new file mode 100644 index 000000000000..6ac0429eeb1d --- /dev/null +++ b/sdk/azcore/arm/runtime/poller_test.go @@ -0,0 +1,308 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "io" + "io/ioutil" + "net/http" + "reflect" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/async" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/loc" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const ( + provStateStarted = `{ "properties": { "provisioningState": "Started" } }` + provStateUpdating = `{ "properties": { "provisioningState": "Updating" } }` + provStateSucceeded = `{ "properties": { "provisioningState": "Succeeded" }, "field": "value" }` + provStateFailed = `{ "properties": { "provisioningState": "Failed" } }` //nolint + statusInProgress = `{ "status": "InProgress" }` + statusSucceeded = `{ "status": "Succeeded" }` + statusCanceled = `{ "status": "Canceled" }` + successResp = `{ "field": "value" }` + errorResp = `{ "error": "the operation failed" }` +) + +type mockType struct { + Field *string `json:"field,omitempty"` +} + +type mockError struct { + Msg string `json:"error"` +} + +func (m mockError) Error() string { + return m.Msg +} + +func getPipeline(srv *mock.Server) pipeline.Pipeline { + return runtime.NewPipeline( + srv, + runtime.NewLogPolicy(nil)) +} + +func handleError(resp *http.Response) error { + var me mockError + if err := runtime.UnmarshalAsJSON(resp, &me); err != nil { + return err + } + return me +} + +func initialResponse(method, u string, resp io.Reader) *http.Response { + req, err := http.NewRequest(method, u, nil) + if err != nil { + panic(err) + } + return &http.Response{ + Body: ioutil.NopCloser(resp), + ContentLength: -1, + Header: http.Header{}, + Request: req, + } +} + +func TestNewPollerAsync(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&async.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + tk, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = NewPollerFromResumeToken("pollerID", tk, pl, handleError) + if err != nil { + t.Fatal(err) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestNewPollerBody(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(provStateUpdating)), mock.WithHeader("Retry-After", "1")) + srv.AppendResponse(mock.WithBody([]byte(provStateSucceeded))) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&body.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + tk, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = NewPollerFromResumeToken("pollerID", tk, pl, handleError) + if err != nil { + t.Fatal(err) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestNewPollerLoc(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderLocation, srv.URL()) + resp.StatusCode = http.StatusAccepted + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&loc.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + tk, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = NewPollerFromResumeToken("pollerID", tk, pl, handleError) + if err != nil { + t.Fatal(err) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestNewPollerInitialRetryAfter(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) + srv.AppendResponse(mock.WithBody([]byte(successResp))) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) + resp.Header.Set("Retry-After", "1") + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&async.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestNewPollerCanceled(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(statusCanceled)), mock.WithStatusCode(http.StatusOK)) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&async.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + _, err = poller.Poll(context.Background()) + if err != nil { + t.Fatal(err) + } + _, err = poller.Poll(context.Background()) + if err == nil { + t.Fatal("unexpected nil error") + } +} + +func TestNewPollerFailedWithError(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) + srv.AppendResponse(mock.WithBody([]byte(errorResp)), mock.WithStatusCode(http.StatusBadRequest)) + resp := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&async.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err == nil { + t.Fatal(err) + } + if _, ok := err.(mockError); !ok { + t.Fatalf("unexpected error type %T", err) + } +} + +func TestNewPollerSuccessNoContent(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(provStateUpdating))) + srv.AppendResponse(mock.WithStatusCode(http.StatusNoContent)) + resp := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp.StatusCode = http.StatusCreated + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err != nil { + t.Fatal(err) + } + if pt := pollers.PollerType(poller); pt != reflect.TypeOf(&body.Poller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + tk, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = NewPollerFromResumeToken("pollerID", tk, pl, handleError) + if err != nil { + t.Fatal(err) + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if result.Field != nil { + t.Fatal("expected nil result") + } +} + +func TestNewPollerFail202NoHeaders(t *testing.T) { + srv, close := mock.NewServer() + defer close() + resp := initialResponse(http.MethodDelete, srv.URL(), http.NoBody) + resp.StatusCode = http.StatusAccepted + pl := getPipeline(srv) + poller, err := NewPoller("pollerID", "", resp, pl, handleError) + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } +} diff --git a/sdk/azcore/core.go b/sdk/azcore/core.go index cfd2ede845ef..660b0cbe07ca 100644 --- a/sdk/azcore/core.go +++ b/sdk/azcore/core.go @@ -7,111 +7,11 @@ package azcore import ( - "errors" - "io" - "net/http" "reflect" -) -const ( - headerContentLength = "Content-Length" - headerContentType = "Content-Type" - headerOperationLocation = "Operation-Location" - headerLocation = "Location" - headerRetryAfter = "Retry-After" - headerUserAgent = "User-Agent" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" ) -// Policy represents an extensibility point for the Pipeline that can mutate the specified -// Request and react to the received Response. -type Policy interface { - // Do applies the policy to the specified Request. When implementing a Policy, mutate the - // request before calling req.Next() to move on to the next policy, and respond to the result - // before returning to the caller. - Do(req *Request) (*http.Response, error) -} - -// policyFunc is a type that implements the Policy interface. -// Use this type when implementing a stateless policy as a first-class function. -type policyFunc func(*Request) (*http.Response, error) - -// Do implements the Policy interface on PolicyFunc. -func (pf policyFunc) Do(req *Request) (*http.Response, error) { - return pf(req) -} - -// Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. -type Transporter interface { - // Do sends the HTTP request and returns the HTTP response or error. - Do(req *http.Request) (*http.Response, error) -} - -// used to adapt a TransportPolicy to a Policy -type transportPolicy struct { - trans Transporter -} - -func (tp transportPolicy) Do(req *Request) (*http.Response, error) { - resp, err := tp.trans.Do(req.Request) - if err != nil { - return nil, err - } else if resp == nil { - // there was no response and no error (rare but can happen) - // this ensures the retry policy will retry the request - return nil, errors.New("received nil response") - } - return resp, nil -} - -// Pipeline represents a primitive for sending HTTP requests and receiving responses. -// Its behavior can be extended by specifying policies during construction. -type Pipeline struct { - policies []Policy -} - -// NewPipeline creates a new Pipeline object from the specified Transport and Policies. -// If no transport is provided then the default *http.Client transport will be used. -func NewPipeline(transport Transporter, policies ...Policy) Pipeline { - if transport == nil { - transport = defaultHTTPClient - } - // transport policy must always be the last in the slice - policies = append(policies, policyFunc(httpHeaderPolicy), policyFunc(bodyDownloadPolicy), transportPolicy{trans: transport}) - return Pipeline{ - policies: policies, - } -} - -// Do is called for each and every HTTP request. It passes the request through all -// the Policy objects (which can transform the Request's URL/query parameters/headers) -// and ultimately sends the transformed HTTP request over the network. -func (p Pipeline) Do(req *Request) (*http.Response, error) { - if err := req.valid(); err != nil { - return nil, err - } - req.policies = p.policies - return req.Next() -} - -// ReadSeekCloser is the interface that groups the io.ReadCloser and io.Seeker interfaces. -type ReadSeekCloser interface { - io.ReadCloser - io.Seeker -} - -type nopCloser struct { - io.ReadSeeker -} - -func (n nopCloser) Close() error { - return nil -} - -// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. -func NopCloser(rs io.ReadSeeker) ReadSeekCloser { - return nopCloser{rs} -} - // holds sentinel values used to send nulls var nullables map[reflect.Type]interface{} = map[reflect.Type]interface{}{} @@ -159,3 +59,6 @@ func IsNullValue(v interface{}) bool { // no sentinel object for this *t return false } + +// Poller encapsulates state and logic for polling on long-running operations. +type Poller = pollers.Poller diff --git a/sdk/azcore/credential.go b/sdk/azcore/credential.go index 441f3d0bc37c..d55a356964f9 100644 --- a/sdk/azcore/credential.go +++ b/sdk/azcore/credential.go @@ -9,31 +9,23 @@ package azcore import ( "context" "time" -) -// AuthenticationOptions contains various options used to create a credential policy. -type AuthenticationOptions struct { - // TokenRequest is a TokenRequestOptions that includes a scopes field which contains - // the list of OAuth2 authentication scopes used when requesting a token. - // This field is ignored for other forms of authentication (e.g. shared key). - TokenRequest TokenRequestOptions - // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate - // in cross-tenant applications. - AuxiliaryTenants []string -} + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) // Credential represents any credential type. type Credential interface { // AuthenticationPolicy returns a policy that requests the credential and applies it to the HTTP request. - NewAuthenticationPolicy(options AuthenticationOptions) Policy + NewAuthenticationPolicy(options runtime.AuthenticationOptions) policy.Policy } // credentialFunc is a type that implements the Credential interface. // Use this type when implementing a stateless credential as a first-class function. -type credentialFunc func(options AuthenticationOptions) Policy +type credentialFunc func(options runtime.AuthenticationOptions) policy.Policy // AuthenticationPolicy implements the Credential interface on credentialFunc. -func (cf credentialFunc) NewAuthenticationPolicy(options AuthenticationOptions) Policy { +func (cf credentialFunc) NewAuthenticationPolicy(options runtime.AuthenticationOptions) policy.Policy { return cf(options) } @@ -41,7 +33,7 @@ func (cf credentialFunc) NewAuthenticationPolicy(options AuthenticationOptions) type TokenCredential interface { Credential // GetToken requests an access token for the specified set of scopes. - GetToken(ctx context.Context, options TokenRequestOptions) (*AccessToken, error) + GetToken(ctx context.Context, options policy.TokenRequestOptions) (*AccessToken, error) } // AccessToken represents an Azure service bearer access token with expiry information. @@ -49,12 +41,3 @@ type AccessToken struct { Token string ExpiresOn time.Time } - -// TokenRequestOptions contain specific parameter that may be used by credentials types when attempting to get a token. -type TokenRequestOptions struct { - // Scopes contains the list of permission scopes required for the token. - Scopes []string - // TenantID contains the tenant ID to use in a multi-tenant authentication scenario, if TenantID is set - // it will override the tenant ID that was added at credential creation time. - TenantID string -} diff --git a/sdk/azcore/error.go b/sdk/azcore/error.go deleted file mode 100644 index e2547faf808c..000000000000 --- a/sdk/azcore/error.go +++ /dev/null @@ -1,70 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "net/http" -) - -var ( - // StackFrameCount contains the number of stack frames to include when a trace is being collected. - StackFrameCount = 32 -) - -// HTTPResponse provides access to an HTTP response when available. -// Errors returned from failed API calls will implement this interface. -// Use errors.As() to access this interface in the error chain. -// If there was no HTTP response then this interface will be omitted -// from any error in the chain. -type HTTPResponse interface { - RawResponse() *http.Response -} - -// NonRetriableError represents a non-transient error. This works in -// conjunction with the retry policy, indicating that the error condition -// is idempotent, so no retries will be attempted. -// Use errors.As() to access this interface in the error chain. -type NonRetriableError interface { - error - NonRetriable() -} - -// NewResponseError wraps the specified error with an error that provides access to an HTTP response. -// If an HTTP request returns a non-successful status code, wrap the response and the associated error -// in this error type so that callers can access the underlying *http.Response as required. -// DO NOT wrap failed HTTP requests that returned an error and no response with this type. -func NewResponseError(inner error, resp *http.Response) error { - return &responseError{inner: inner, resp: resp} -} - -type responseError struct { - inner error - resp *http.Response -} - -// Error implements the error interface for type ResponseError. -func (e *responseError) Error() string { - return e.inner.Error() -} - -// Unwrap returns the inner error. -func (e *responseError) Unwrap() error { - return e.inner -} - -// RawResponse returns the HTTP response associated with this error. -func (e *responseError) RawResponse() *http.Response { - return e.resp -} - -// NonRetriable indicates this error is non-transient. -func (e *responseError) NonRetriable() { - // marker method -} - -var _ HTTPResponse = (*responseError)(nil) -var _ NonRetriableError = (*responseError)(nil) diff --git a/sdk/azcore/errors.go b/sdk/azcore/errors.go new file mode 100644 index 000000000000..222c8f85f21b --- /dev/null +++ b/sdk/azcore/errors.go @@ -0,0 +1,26 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcore + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" +) + +// HTTPResponse provides access to an HTTP response when available. +// Errors returned from failed API calls will implement this interface. +// Use errors.As() to access this interface in the error chain. +// If there was no HTTP response then this interface will be omitted +// from any error in the chain. +type HTTPResponse interface { + RawResponse() *http.Response +} + +var _ HTTPResponse = (*shared.ResponseError)(nil) +var _ errorinfo.NonRetriable = (*shared.ResponseError)(nil) diff --git a/sdk/azcore/example_test.go b/sdk/azcore/example_test.go index 0506a3f75bac..b13cb08a11fd 100644 --- a/sdk/azcore/example_test.go +++ b/sdk/azcore/example_test.go @@ -8,58 +8,23 @@ package azcore_test import ( - "context" "encoding/json" "fmt" - "io/ioutil" - "log" - "net/http" - "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" ) -func ExamplePipeline_Do() { - req, err := azcore.NewRequest(context.Background(), http.MethodGet, "https://github.com/robots.txt") - if err != nil { - log.Fatal(err) - } - pipeline := azcore.NewPipeline(nil) - resp, err := pipeline.Do(req) - if err != nil { - log.Fatal(err) - } - robots, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - log.Fatal(err) - } - fmt.Printf("%s", robots) -} - -func ExampleRequest_SetBody() { - req, err := azcore.NewRequest(context.Background(), http.MethodPut, "https://contoso.com/some/endpoint") - if err != nil { - log.Fatal(err) - } - body := strings.NewReader("this is seekable content to be uploaded") - err = req.SetBody(azcore.NopCloser(body), "text/plain") - if err != nil { - log.Fatal(err) - } -} - // false positive by linter func ExampleSetClassifications() { //nolint:govet // only log HTTP requests and responses - azlog.SetClassifications(azlog.Request, azlog.Response) + log.SetClassifications(log.Request, log.Response) } // false positive by linter func ExampleSetListener() { //nolint:govet // a simple logger that writes to stdout - azlog.SetListener(func(cls azlog.Classification, msg string) { + log.SetListener(func(cls log.Classification, msg string) { fmt.Printf("%s: %s\n", cls, msg) }) } diff --git a/sdk/azcore/internal/pipeline/pipeline.go b/sdk/azcore/internal/pipeline/pipeline.go new file mode 100644 index 000000000000..e2c9f115a1d7 --- /dev/null +++ b/sdk/azcore/internal/pipeline/pipeline.go @@ -0,0 +1,93 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pipeline + +import ( + "errors" + "fmt" + "net/http" + + "golang.org/x/net/http/httpguts" +) + +// Policy represents an extensibility point for the Pipeline that can mutate the specified +// Request and react to the received Response. +type Policy interface { + // Do applies the policy to the specified Request. When implementing a Policy, mutate the + // request before calling req.Next() to move on to the next policy, and respond to the result + // before returning to the caller. + Do(req *Request) (*http.Response, error) +} + +// Pipeline represents a primitive for sending HTTP requests and receiving responses. +// Its behavior can be extended by specifying policies during construction. +type Pipeline struct { + policies []Policy +} + +// Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. +type Transporter interface { + // Do sends the HTTP request and returns the HTTP response or error. + Do(req *http.Request) (*http.Response, error) +} + +// used to adapt a TransportPolicy to a Policy +type transportPolicy struct { + trans Transporter +} + +func (tp transportPolicy) Do(req *Request) (*http.Response, error) { + if tp.trans == nil { + return nil, errors.New("missing transporter") + } + resp, err := tp.trans.Do(req.Raw()) + if err != nil { + return nil, err + } else if resp == nil { + // there was no response and no error (rare but can happen) + // this ensures the retry policy will retry the request + return nil, errors.New("received nil response") + } + return resp, nil +} + +// NewPipeline creates a new Pipeline object from the specified Policies. +func NewPipeline(transport Transporter, policies ...Policy) Pipeline { + // transport policy must always be the last in the slice + policies = append(policies, transportPolicy{trans: transport}) + return Pipeline{ + policies: policies, + } +} + +// Do is called for each and every HTTP request. It passes the request through all +// the Policy objects (which can transform the Request's URL/query parameters/headers) +// and ultimately sends the transformed HTTP request over the network. +func (p Pipeline) Do(req *Request) (*http.Response, error) { + if req == nil { + return nil, errors.New("request cannot be nil") + } + // check copied from Transport.roundTrip() + for k, vv := range req.Raw().Header { + if !httpguts.ValidHeaderFieldName(k) { + if req.Raw().Body != nil { + req.Raw().Body.Close() + } + return nil, fmt.Errorf("invalid header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + if req.Raw().Body != nil { + req.Raw().Body.Close() + } + return nil, fmt.Errorf("invalid header field value %q for key %v", v, k) + } + } + } + req.policies = p.policies + return req.Next() +} diff --git a/sdk/azcore/internal/pipeline/pipeline_test.go b/sdk/azcore/internal/pipeline/pipeline_test.go new file mode 100644 index 000000000000..81bc22e698af --- /dev/null +++ b/sdk/azcore/internal/pipeline/pipeline_test.go @@ -0,0 +1,103 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pipeline + +import ( + "context" + "errors" + "net/http" + "testing" +) + +func TestPipelineErrors(t *testing.T) { + pl := NewPipeline(nil) + resp, err := pl.Do(nil) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + req, err := NewRequest(context.Background(), http.MethodGet, testURL) + if err != nil { + t.Fatal(err) + } + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + req.Raw().Header["Invalid"] = []string{string([]byte{0})} + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + req, err = NewRequest(context.Background(), http.MethodGet, testURL) + if err != nil { + t.Fatal(err) + } + req.Raw().Header["Inv alid"] = []string{"value"} + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +type mockTransport struct { + succeed bool + both bool +} + +func (m *mockTransport) Do(*http.Request) (*http.Response, error) { + if m.both { + return nil, nil + } + if m.succeed { + return &http.Response{StatusCode: http.StatusOK}, nil + } + return nil, errors.New("failed") +} + +func TestPipelineDo(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodGet, testURL) + if err != nil { + t.Fatal(err) + } + tp := mockTransport{succeed: true} + pl := NewPipeline(&tp) + resp, err := pl.Do(req) + if err != nil { + t.Fatal(err) + } + if sc := resp.StatusCode; sc != http.StatusOK { + t.Fatalf("unexpected status code %d", sc) + } + tp.succeed = false + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + tp.both = true + resp, err = pl.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } +} diff --git a/sdk/azcore/internal/pipeline/request.go b/sdk/azcore/internal/pipeline/request.go new file mode 100644 index 000000000000..e261f30429a2 --- /dev/null +++ b/sdk/azcore/internal/pipeline/request.go @@ -0,0 +1,169 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pipeline + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "reflect" + "strconv" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +// PolicyFunc is a type that implements the Policy interface. +// Use this type when implementing a stateless policy as a first-class function. +type PolicyFunc func(*Request) (*http.Response, error) + +// Do implements the Policy interface on PolicyFunc. +func (pf PolicyFunc) Do(req *Request) (*http.Response, error) { + return pf(req) +} + +// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. +// Don't use this type directly, use NewRequest() instead. +type Request struct { + req *http.Request + body io.ReadSeekCloser + policies []Policy + values opValues +} + +type opValues map[reflect.Type]interface{} + +// Set adds/changes a value +func (ov opValues) set(value interface{}) { + ov[reflect.TypeOf(value)] = value +} + +// Get looks for a value set by SetValue first +func (ov opValues) get(value interface{}) bool { + v, ok := ov[reflect.ValueOf(value).Elem().Type()] + if ok { + reflect.ValueOf(value).Elem().Set(reflect.ValueOf(v)) + } + return ok +} + +// NewRequest creates a new Request with the specified input. +func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) { + req, err := http.NewRequestWithContext(ctx, httpMethod, endpoint, nil) + if err != nil { + return nil, err + } + if req.URL.Host == "" { + return nil, errors.New("no Host in request URL") + } + if !(req.URL.Scheme == "http" || req.URL.Scheme == "https") { + return nil, fmt.Errorf("unsupported protocol scheme %s", req.URL.Scheme) + } + return &Request{req: req}, nil +} + +// Body returns the original body specified when the Request was created. +func (req *Request) Body() io.ReadSeekCloser { + return req.body +} + +// Raw returns the underlying HTTP request. +func (req *Request) Raw() *http.Request { + return req.req +} + +// Next calls the next policy in the pipeline. +// If there are no more policies, nil and an error are returned. +// This method is intended to be called from pipeline policies. +// To send a request through a pipeline call Pipeline.Do(). +func (req *Request) Next() (*http.Response, error) { + if len(req.policies) == 0 { + return nil, errors.New("no more policies") + } + nextPolicy := req.policies[0] + nextReq := *req + nextReq.policies = nextReq.policies[1:] + return nextPolicy.Do(&nextReq) +} + +// SetOperationValue adds/changes a mutable key/value associated with a single operation. +func (req *Request) SetOperationValue(value interface{}) { + if req.values == nil { + req.values = opValues{} + } + req.values.set(value) +} + +// OperationValue looks for a value set by SetOperationValue(). +func (req *Request) OperationValue(value interface{}) bool { + if req.values == nil { + return false + } + return req.values.get(value) +} + +// SetBody sets the specified ReadSeekCloser as the HTTP request body. +func (req *Request) SetBody(body io.ReadSeekCloser, contentType string) error { + // Set the body and content length. + size, err := body.Seek(0, io.SeekEnd) // Seek to the end to get the stream's size + if err != nil { + return err + } + if size == 0 { + body.Close() + return nil + } + _, err = body.Seek(0, io.SeekStart) + if err != nil { + return err + } + req.Raw().GetBody = func() (io.ReadCloser, error) { + _, err := body.Seek(0, io.SeekStart) // Seek back to the beginning of the stream + return body, err + } + // keep a copy of the original body. this is to handle cases + // where req.Body is replaced, e.g. httputil.DumpRequest and friends. + req.body = body + req.req.Body = body + req.req.ContentLength = size + req.req.Header.Set(shared.HeaderContentType, contentType) + req.req.Header.Set(shared.HeaderContentLength, strconv.FormatInt(size, 10)) + return nil +} + +// SkipBodyDownload will disable automatic downloading of the response body. +func (req *Request) SkipBodyDownload() { + req.SetOperationValue(shared.BodyDownloadPolicyOpValues{Skip: true}) +} + +// RewindBody seeks the request's Body stream back to the beginning so it can be resent when retrying an operation. +func (req *Request) RewindBody() error { + if req.body != nil { + // Reset the stream back to the beginning and restore the body + _, err := req.body.Seek(0, io.SeekStart) + req.req.Body = req.body + return err + } + return nil +} + +// Close closes the request body. +func (req *Request) Close() error { + if req.body == nil { + return nil + } + return req.body.Close() +} + +// Clone returns a deep copy of the request with its context changed to ctx. +func (req *Request) Clone(ctx context.Context) *Request { + r2 := Request{} + r2 = *req + r2.req = req.req.Clone(ctx) + return &r2 +} diff --git a/sdk/azcore/internal/pipeline/request_test.go b/sdk/azcore/internal/pipeline/request_test.go new file mode 100644 index 000000000000..4677417861cc --- /dev/null +++ b/sdk/azcore/internal/pipeline/request_test.go @@ -0,0 +1,139 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pipeline + +import ( + "context" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const testURL = "http://test.contoso.com/" + +func TestNewRequest(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPost, testURL) + if err != nil { + t.Fatal(err) + } + if m := req.Raw().Method; m != http.MethodPost { + t.Fatalf("unexpected method %s", m) + } + type myValue struct{} + var mv myValue + if req.OperationValue(&mv) { + t.Fatal("expected missing custom operation value") + } + req.SetOperationValue(myValue{}) + if !req.OperationValue(&mv) { + t.Fatal("missing custom operation value") + } +} + +func TestRequestPolicies(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPost, testURL) + if err != nil { + t.Fatal(err) + } + resp, err := req.Next() + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + req.policies = []Policy{} + resp, err = req.Next() + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + testPolicy := func(*Request) (*http.Response, error) { + return &http.Response{}, nil + } + req.policies = []Policy{PolicyFunc(testPolicy)} + resp, err = req.Next() + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("unexpected nil response") + } +} + +func TestRequestBody(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPost, testURL) + if err != nil { + t.Fatal(err) + } + req.SkipBodyDownload() + if err := req.RewindBody(); err != nil { + t.Fatal(err) + } + if err := req.Close(); err != nil { + t.Fatal(err) + } + if err := req.SetBody(shared.NopCloser(strings.NewReader("test")), "application/text"); err != nil { + t.Fatal(err) + } + if err := req.RewindBody(); err != nil { + t.Fatal(err) + } + if err := req.Close(); err != nil { + t.Fatal(err) + } +} + +func TestRequestClone(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodPost, testURL) + if err != nil { + t.Fatal(err) + } + req.SkipBodyDownload() + if err := req.SetBody(shared.NopCloser(strings.NewReader("test")), "application/text"); err != nil { + t.Fatal(err) + } + clone := req.Clone(context.Background()) + var skip shared.BodyDownloadPolicyOpValues + if !clone.OperationValue(&skip) { + t.Fatal("missing operation value") + } + if !skip.Skip { + t.Fatal("wrong operation value") + } + if clone.body == nil { + t.Fatal("missing body") + } +} + +func TestNewRequestFail(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodOptions, "://test.contoso.com/") + if err == nil { + t.Fatal("unexpected nil error") + } + if req != nil { + t.Fatal("unexpected request") + } + req, err = NewRequest(context.Background(), http.MethodPatch, "/missing/the/host") + if err == nil { + t.Fatal("unexpected nil error") + } + if req != nil { + t.Fatal("unexpected request") + } + req, err = NewRequest(context.Background(), http.MethodPatch, "mailto://nobody.contoso.com") + if err == nil { + t.Fatal("unexpected nil error") + } + if req != nil { + t.Fatal("unexpected request") + } +} diff --git a/sdk/azcore/internal/pollers/loc/loc.go b/sdk/azcore/internal/pollers/loc/loc.go new file mode 100644 index 000000000000..357a3d7d5966 --- /dev/null +++ b/sdk/azcore/internal/pollers/loc/loc.go @@ -0,0 +1,80 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "Location" + +// Applicable returns true if the LRO is using Location. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderLocation) != "" +} + +// Poller is an LRO poller that uses the Location pattern. +type Poller struct { + Type string `json:"type"` + PollURL string `json:"pollURL"` + CurState int `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *http.Response, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Location poller.") + locURL := resp.Header.Get(shared.HeaderLocation) + if locURL == "" { + return nil, errors.New("response is missing Location header") + } + if !pollers.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid polling URL %s", locURL) + } + return &Poller{ + Type: pollers.MakeID(pollerID, Kind), + PollURL: locURL, + CurState: resp.StatusCode, + }, nil +} + +func (p *Poller) URL() string { + return p.PollURL +} + +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +func (p *Poller) Update(resp *http.Response) error { + // if the endpoint returned a location header, update cached value + if loc := resp.Header.Get(shared.HeaderLocation); loc != "" { + p.PollURL = loc + } + p.CurState = resp.StatusCode + return nil +} + +func (*Poller) FinalGetURL() string { + return "" +} + +func (p *Poller) Status() string { + if p.CurState == http.StatusAccepted { + return pollers.StatusInProgress + } else if p.CurState > 199 && p.CurState < 300 { + // any 2xx other than a 202 indicates success + return pollers.StatusSucceeded + } + return pollers.StatusFailed +} diff --git a/sdk/azcore/internal/pollers/loc/loc_test.go b/sdk/azcore/internal/pollers/loc/loc_test.go new file mode 100644 index 000000000000..6fa70aef14d2 --- /dev/null +++ b/sdk/azcore/internal/pollers/loc/loc_test.go @@ -0,0 +1,136 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package loc + +import ( + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakeLocationURL = "https://foo.bar.baz/status" + fakeLocationURL2 = "https://foo.bar.baz/status/other" +) + +func initialResponse() *http.Response { + return &http.Response{ + Header: http.Header{}, + StatusCode: http.StatusAccepted, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + } + if Applicable(resp) { + t.Fatal("missing Location should not be applicable") + } + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + if !Applicable(resp) { + t.Fatal("having Location should be applicable") + } +} + +func TestNew(t *testing.T) { + resp := initialResponse() + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatalf("unexpected final get URL %s", u) + } + if s := poller.Status(); s != pollers.StatusInProgress { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakeLocationURL { + t.Fatalf("unexpected polling URL %s", u) + } +} + +func TestNewFail(t *testing.T) { + resp := initialResponse() + poller, err := New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } + resp.Header.Set(shared.HeaderLocation, "/must/be/absolute") + poller, err = New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } +} + +func TestUpdateSucceeded(t *testing.T) { + resp := initialResponse() + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp.Header.Set(shared.HeaderLocation, fakeLocationURL2) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.URL(); u != fakeLocationURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(&http.Response{StatusCode: http.StatusOK}); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusSucceeded { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateFailed(t *testing.T) { + resp := initialResponse() + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp.Header.Set(shared.HeaderLocation, fakeLocationURL2) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.URL(); u != fakeLocationURL2 { + t.Fatalf("unexpected polling URL %s", u) + } + if err := poller.Update(&http.Response{StatusCode: http.StatusConflict}); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusFailed { + t.Fatalf("unexpected status %s", s) + } +} diff --git a/sdk/azcore/internal/pollers/op/op.go b/sdk/azcore/internal/pollers/op/op.go new file mode 100644 index 000000000000..730a85dfa795 --- /dev/null +++ b/sdk/azcore/internal/pollers/op/op.go @@ -0,0 +1,132 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package op + +import ( + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Kind is the identifier of this type in a resume token. +const Kind = "Operation-Location" + +// Applicable returns true if the LRO is using Operation-Location. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderOperationLocation) != "" +} + +// Poller is an LRO poller that uses the Operation-Location pattern. +type Poller struct { + Type string `json:"type"` + PollURL string `json:"pollURL"` + LocURL string `json:"locURL"` + FinalGET string `json:"finalGET"` + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +func New(resp *http.Response, pollerID string) (*Poller, error) { + log.Write(log.LongRunningOperation, "Using Operation-Location poller.") + opURL := resp.Header.Get(shared.HeaderOperationLocation) + if opURL == "" { + return nil, errors.New("response is missing Operation-Location header") + } + if !pollers.IsValidURL(opURL) { + return nil, fmt.Errorf("invalid Operation-Location URL %s", opURL) + } + locURL := resp.Header.Get(shared.HeaderLocation) + // Location header is optional + if locURL != "" && !pollers.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid Location URL %s", locURL) + } + // default initial state to InProgress. if the + // service sent us a status then use that instead. + curState := pollers.StatusInProgress + status, err := getValue(resp, "status") + if err != nil && !errors.Is(err, shared.ErrNoBody) { + return nil, err + } + if status != "" { + curState = status + } + // calculate the tentative final GET URL. + // can change if we receive a resourceLocation. + // it's ok for it to be empty in some cases. + finalGET := "" + if resp.Request.Method == http.MethodPatch || resp.Request.Method == http.MethodPut { + finalGET = resp.Request.URL.String() + } else if resp.Request.Method == http.MethodPost && locURL != "" { + finalGET = locURL + } + return &Poller{ + Type: pollers.MakeID(pollerID, Kind), + PollURL: opURL, + LocURL: locURL, + FinalGET: finalGET, + CurState: curState, + }, nil +} + +func (p *Poller) URL() string { + return p.PollURL +} + +func (p *Poller) Done() bool { + return pollers.IsTerminalState(p.Status()) +} + +func (p *Poller) Update(resp *http.Response) error { + status, err := getValue(resp, "status") + if err != nil { + return err + } else if status == "" { + return errors.New("the response did not contain a status") + } + p.CurState = status + // if the endpoint returned an operation-location header, update cached value + if opLoc := resp.Header.Get(shared.HeaderOperationLocation); opLoc != "" { + p.PollURL = opLoc + } + // check for resourceLocation + resLoc, err := getValue(resp, "resourceLocation") + if err != nil && !errors.Is(err, shared.ErrNoBody) { + return err + } else if resLoc != "" { + p.FinalGET = resLoc + } + return nil +} + +func (p *Poller) FinalGetURL() string { + return p.FinalGET +} + +func (p *Poller) Status() string { + return p.CurState +} + +func getValue(resp *http.Response, val string) (string, error) { + jsonBody, err := shared.GetJSON(resp) + if err != nil { + return "", err + } + v, ok := jsonBody[val] + if !ok { + // it might be ok if the field doesn't exist, the caller must make that determination + return "", nil + } + vv, ok := v.(string) + if !ok { + return "", fmt.Errorf("the %s value %v was not in string format", val, v) + } + return vv, nil +} diff --git a/sdk/azcore/internal/pollers/op/op_test.go b/sdk/azcore/internal/pollers/op/op_test.go new file mode 100644 index 000000000000..55c3f2253ea4 --- /dev/null +++ b/sdk/azcore/internal/pollers/op/op_test.go @@ -0,0 +1,249 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package op + +import ( + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + fakePollingURL = "https://foo.bar.baz/status" + fakePollingURL2 = "https://foo.bar.baz/status/updated" + fakeLocationURL = "https://foo.bar.baz/location" + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(method string, body io.Reader) *http.Response { + req, err := http.NewRequest(method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &http.Response{ + Body: ioutil.NopCloser(body), + Header: http.Header{}, + Request: req, + } +} + +func createResponse(body io.Reader) *http.Response { + return &http.Response{ + Body: ioutil.NopCloser(body), + Header: http.Header{}, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + } + if Applicable(resp) { + t.Fatal("missing Operation-Location should not be applicable") + } + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + if !Applicable(resp) { + t.Fatal("having Operation-Location should be applicable") + } +} + +func TestNew(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.FinalGetURL(); u != fakeResourceURL { + t.Fatalf("unexpected final get URL %s", u) + } + if s := poller.Status(); s != pollers.StatusInProgress { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL { + t.Fatalf("unexpected URL %s", u) + } +} + +func TestNewWithInitialStatus(t *testing.T) { + resp := initialResponse(http.MethodPut, strings.NewReader(`{ "status": "Updating" }`)) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if s := poller.Status(); s != "Updating" { + t.Fatalf("unexpected status %s", s) + } +} + +func TestNewWithPost(t *testing.T) { + resp := initialResponse(http.MethodPost, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.FinalGetURL(); u != fakeLocationURL { + t.Fatalf("unexpected final get URL %s", u) + } +} + +func TestNewWithDelete(t *testing.T) { + resp := initialResponse(http.MethodDelete, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if u := poller.FinalGetURL(); u != "" { + t.Fatalf("unexpected final get URL %s", u) + } +} + +func TestNewFail(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + poller, err := New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, "/must/be/absolute") + poller, err = New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } + resp.Header.Set(shared.HeaderOperationLocation, "/must/be/absolute") + poller, err = New(resp, "pollerID") + if err == nil { + t.Fatal("unexpected nil error") + } + if poller != nil { + t.Fatal("expected nil poller") + } +} + +func TestUpdateSucceeded(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp = createResponse(strings.NewReader(`{ "status": "Running" }`)) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL2) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if poller.Done() { + t.Fatal("unexpected done") + } + if s := poller.Status(); s != "Running" { + t.Fatalf("unexpected status %s", s) + } + if u := poller.URL(); u != fakePollingURL2 { + t.Fatalf("unexpected URL %s", u) + } + resp = createResponse(strings.NewReader(`{ "status": "Succeeded" }`)) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusSucceeded { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateResourceLocation(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp = createResponse(strings.NewReader(`{ "status": "Succeeded", "resourceLocation": "https://foo.bar.baz/resource2" }`)) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusSucceeded { + t.Fatalf("unexpected status %s", s) + } + if u := poller.FinalGetURL(); u != "https://foo.bar.baz/resource2" { + t.Fatalf("unexpected final get url %s", u) + } +} + +func TestUpdateFailed(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp = createResponse(strings.NewReader(`{ "status": "Failed" }`)) + if err := poller.Update(resp); err != nil { + t.Fatal(err) + } + if !poller.Done() { + t.Fatal("expected done") + } + if s := poller.Status(); s != pollers.StatusFailed { + t.Fatalf("unexpected status %s", s) + } +} + +func TestUpdateMissingStatus(t *testing.T) { + resp := initialResponse(http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderOperationLocation, fakePollingURL) + resp.Header.Set(shared.HeaderLocation, fakeLocationURL) + poller, err := New(resp, "pollerID") + if err != nil { + t.Fatal(err) + } + resp = createResponse(http.NoBody) + if err := poller.Update(resp); err == nil { + t.Fatal("unexpected nil error") + } + if poller.Done() { + t.Fatal("unexpected done") + } +} diff --git a/sdk/azcore/internal/pollers/poller.go b/sdk/azcore/internal/pollers/poller.go new file mode 100644 index 000000000000..aca2f3197b72 --- /dev/null +++ b/sdk/azcore/internal/pollers/poller.go @@ -0,0 +1,213 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "reflect" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// KindFromToken extracts the poller kind from the provided token. +// If the pollerID doesn't match what's in the token an error is returned. +func KindFromToken(pollerID, token string) (string, error) { + // unmarshal into JSON object to determine the poller type + obj := map[string]interface{}{} + err := json.Unmarshal([]byte(token), &obj) + if err != nil { + return "", err + } + t, ok := obj["type"] + if !ok { + return "", errors.New("missing type field") + } + tt, ok := t.(string) + if !ok { + return "", fmt.Errorf("invalid type format %T", t) + } + ttID, ttKind, err := DecodeID(tt) + if err != nil { + return "", err + } + // ensure poller types match + if ttID != pollerID { + return "", fmt.Errorf("cannot resume from this poller token. expected %s, received %s", pollerID, ttID) + } + return ttKind, nil +} + +// PollerType returns the concrete type of the poller (FOR TESTING PURPOSES). +func PollerType(p *Poller) reflect.Type { + return reflect.TypeOf(p.lro) +} + +// NewPoller creates a Poller from the specified input. +func NewPoller(lro Operation, resp *http.Response, pl pipeline.Pipeline, eu func(*http.Response) error) *Poller { + return &Poller{lro: lro, pl: pl, eu: eu, resp: resp} +} + +// Poller encapsulates state and logic for polling on long-running operations. +type Poller struct { + lro Operation + pl pipeline.Pipeline + eu func(*http.Response) error + resp *http.Response + err error +} + +// Done returns true if the LRO has reached a terminal state. +func (l *Poller) Done() bool { + if l.err != nil { + return true + } + return l.lro.Done() +} + +// Poll sends a polling request to the polling endpoint and returns the response or error. +func (l *Poller) Poll(ctx context.Context) (*http.Response, error) { + if l.Done() { + // the LRO has reached a terminal state, don't poll again + if l.resp != nil { + return l.resp, nil + } + return nil, l.err + } + req, err := pipeline.NewRequest(ctx, http.MethodGet, l.lro.URL()) + if err != nil { + return nil, err + } + resp, err := l.pl.Do(req) + if err != nil { + // don't update the poller for failed requests + return nil, err + } + defer resp.Body.Close() + if !StatusCodeValid(resp) { + // the LRO failed. unmarshall the error and update state + l.err = l.eu(resp) + l.resp = nil + return nil, l.err + } + if err = l.lro.Update(resp); err != nil { + return nil, err + } + l.resp = resp + log.Writef(log.LongRunningOperation, "Status %s", l.lro.Status()) + if Failed(l.lro.Status()) { + l.err = l.eu(resp) + l.resp = nil + return nil, l.err + } + return l.resp, nil +} + +// ResumeToken returns a token string that can be used to resume a poller that has not yet reached a terminal state. +func (l *Poller) ResumeToken() (string, error) { + if l.Done() { + return "", errors.New("cannot create a ResumeToken from a poller in a terminal state") + } + b, err := json.Marshal(l.lro) + if err != nil { + return "", err + } + return string(b), nil +} + +// FinalResponse will perform a final GET request and return the final HTTP response for the polling +// operation and unmarshall the content of the payload into the respType interface that is provided. +func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http.Response, error) { + if !l.Done() { + return nil, errors.New("cannot return a final response from a poller in a non-terminal state") + } + // update l.resp with the content from final GET if applicable + if u := l.lro.FinalGetURL(); u != "" { + log.Write(log.LongRunningOperation, "Performing final GET.") + req, err := pipeline.NewRequest(ctx, http.MethodGet, u) + if err != nil { + return nil, err + } + resp, err := l.pl.Do(req) + if err != nil { + return nil, err + } + if !StatusCodeValid(resp) { + return nil, l.eu(resp) + } + l.resp = resp + } + // if there's nothing to unmarshall into or no response body just return the final response + if respType == nil { + return l.resp, nil + } else if l.resp.StatusCode == http.StatusNoContent || l.resp.ContentLength == 0 { + log.Write(log.LongRunningOperation, "final response specifies a response type but no payload was received") + return l.resp, nil + } + body, err := ioutil.ReadAll(l.resp.Body) + l.resp.Body.Close() + if err != nil { + return nil, err + } + if err = json.Unmarshal(body, respType); err != nil { + return nil, err + } + return l.resp, nil +} + +// PollUntilDone will handle the entire span of the polling operation until a terminal state is reached, +// then return the final HTTP response for the polling operation and unmarshal the content of the payload +// into the respType interface that is provided. +// freq - the time to wait between polling intervals if the endpoint doesn't send a Retry-After header. +// A good starting value is 30 seconds. Note that some resources might benefit from a different value. +func (l *Poller) PollUntilDone(ctx context.Context, freq time.Duration, respType interface{}) (*http.Response, error) { + start := time.Now() + logPollUntilDoneExit := func(v interface{}) { + log.Writef(log.LongRunningOperation, "END PollUntilDone() for %T: %v, total time: %s", l.lro, v, time.Since(start)) + } + log.Writef(log.LongRunningOperation, "BEGIN PollUntilDone() for %T", l.lro) + if l.resp != nil { + // initial check for a retry-after header existing on the initial response + if retryAfter := shared.RetryAfter(l.resp); retryAfter > 0 { + log.Writef(log.LongRunningOperation, "initial Retry-After delay for %s", retryAfter.String()) + if err := shared.Delay(ctx, retryAfter); err != nil { + logPollUntilDoneExit(err) + return nil, err + } + } + } + // begin polling the endpoint until a terminal state is reached + for { + resp, err := l.Poll(ctx) + if err != nil { + logPollUntilDoneExit(err) + return nil, err + } + if l.Done() { + logPollUntilDoneExit(l.lro.Status()) + return l.FinalResponse(ctx, respType) + } + d := freq + if retryAfter := shared.RetryAfter(resp); retryAfter > 0 { + log.Writef(log.LongRunningOperation, "Retry-After delay for %s", retryAfter.String()) + d = retryAfter + } else { + log.Writef(log.LongRunningOperation, "delay for %s", d.String()) + } + if err = shared.Delay(ctx, d); err != nil { + logPollUntilDoneExit(err) + return nil, err + } + } +} diff --git a/sdk/azcore/internal/pollers/poller_test.go b/sdk/azcore/internal/pollers/poller_test.go new file mode 100644 index 000000000000..01dc426cab51 --- /dev/null +++ b/sdk/azcore/internal/pollers/poller_test.go @@ -0,0 +1,256 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +func TestKindFromToken(t *testing.T) { + const tk = `{ "type": "pollerID;kind" }` + k, err := KindFromToken("pollerID", tk) + if err != nil { + t.Fatal(err) + } + if k != "kind" { + t.Fatalf("unexpected kind %s", k) + } + k, err = KindFromToken("mismatched", tk) + if err == nil { + t.Fatal("unexpected nil error") + } + if k != "" { + t.Fatal("expected empty kind") + } +} + +func TestKindFromTokenInvalid(t *testing.T) { + const tk1 = `{ "missing": "type" }` + k, err := KindFromToken("mismatched", tk1) + if err == nil { + t.Fatal("unexpected nil error") + } + if k != "" { + t.Fatal("expected empty kind") + } + const tk2 = `{ "type": false }` + k, err = KindFromToken("mismatched", tk2) + if err == nil { + t.Fatal("unexpected nil error") + } + if k != "" { + t.Fatal("expected empty kind") + } + const tk3 = `{ "type": "pollerID;kind;extra" }` + k, err = KindFromToken("mismatched", tk3) + if err == nil { + t.Fatal("unexpected nil error") + } + if k != "" { + t.Fatal("expected empty kind") + } +} + +// simple status code-based poller +type fakePoller struct { + Ep string + Fg string + Code int +} + +func (f *fakePoller) Done() bool { + return f.Code == http.StatusOK || f.Code == http.StatusNoContent +} + +func (f *fakePoller) Update(resp *http.Response) error { + f.Code = resp.StatusCode + return nil +} + +func (f *fakePoller) FinalGetURL() string { + return f.Fg +} + +func (f *fakePoller) URL() string { + return f.Ep +} + +func (f *fakePoller) Status() string { + switch f.Code { + case http.StatusAccepted: + return StatusInProgress + case http.StatusOK, http.StatusNoContent: + return StatusSucceeded + case http.StatusCreated: + return StatusCanceled + default: + return StatusFailed + } +} + +func TestNewPoller(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusNoContent)) // terminal + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{}, + } + firstResp.Header.Set(shared.HeaderRetryAfter, "1") + p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + if p.Done() { + t.Fatal("unexpected done") + } + resp, err := p.FinalResponse(context.Background(), nil) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + tk, err := p.ResumeToken() + if err != nil { + t.Fatal(err) + } + if tk == "" { + t.Fatal("unexpected empty resume token") + } + resp, err = p.PollUntilDone(context.Background(), 1*time.Millisecond, nil) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } + tk, err = p.ResumeToken() + if err == nil { + t.Fatal("unexpected nil error") + } + if tk != "" { + t.Fatal("expected empty resume token") + } +} + +func TestNewPollerWithFinalGET(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithHeader(shared.HeaderRetryAfter, "1")) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) // terminal + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{ "shape": "round" }`))) // final GET + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + } + p := NewPoller(&fakePoller{Ep: srv.URL(), Fg: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + if p.Done() { + t.Fatal("unexpected done") + } + type widget struct { + Shape string `json:"shape"` + } + var w widget + resp, err := p.PollUntilDone(context.Background(), 1*time.Millisecond, &w) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } + if w.Shape != "round" { + t.Fatalf("unexpected result %s", w.Shape) + } + resp, err = p.Poll(context.Background()) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } +} + +func TestNewPollerFail1(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict)) // terminal + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + } + p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + resp, err := p.PollUntilDone(context.Background(), 1*time.Millisecond, nil) + if err == nil { + t.Fatal("unexpected nil error") + } else if s := err.Error(); s != "failed" { + t.Fatalf("unexpected error %s", s) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestNewPollerFail2(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusCreated)) // terminal + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + } + p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + resp, err := p.PollUntilDone(context.Background(), 1*time.Millisecond, nil) + if err == nil { + t.Fatal("unexpected nil error") + } else if s := err.Error(); s != "failed" { + t.Fatalf("unexpected error %s", s) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestNewPollerError(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendError(errors.New("fatal")) + defer close() + pl := pipeline.NewPipeline(srv) + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + } + p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl, func(*http.Response) error { + return errors.New("failed") + }) + resp, err := p.PollUntilDone(context.Background(), 1*time.Millisecond, nil) + if err == nil { + t.Fatal("unexpected nil error") + } else if s := err.Error(); s != "fatal" { + t.Fatalf("unexpected error %s", s) + } + if resp != nil { + t.Fatal("expected nil response") + } +} diff --git a/sdk/azcore/internal/pollers/util.go b/sdk/azcore/internal/pollers/util.go new file mode 100644 index 000000000000..dca70b5a596b --- /dev/null +++ b/sdk/azcore/internal/pollers/util.go @@ -0,0 +1,99 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +const ( + StatusSucceeded = "Succeeded" + StatusCanceled = "Canceled" + StatusFailed = "Failed" + StatusInProgress = "InProgress" +) + +// Operation abstracts the differences between concrete poller types. +type Operation interface { + Done() bool + Update(resp *http.Response) error + FinalGetURL() string + URL() string + Status() string +} + +// IsTerminalState returns true if the LRO's state is terminal. +func IsTerminalState(s string) bool { + return strings.EqualFold(s, StatusSucceeded) || strings.EqualFold(s, StatusFailed) || strings.EqualFold(s, StatusCanceled) +} + +// Failed returns true if the LRO's state is terminal failure. +func Failed(s string) bool { + return strings.EqualFold(s, StatusFailed) || strings.EqualFold(s, StatusCanceled) +} + +// returns true if the LRO response contains a valid HTTP status code +func StatusCodeValid(resp *http.Response) bool { + return shared.HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) +} + +// IsValidURL verifies that the URL is valid and absolute. +func IsValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.IsAbs() +} + +const idSeparator = ";" + +// MakeID returns the poller ID from the provided values. +func MakeID(pollerID string, kind string) string { + return fmt.Sprintf("%s%s%s", pollerID, idSeparator, kind) +} + +// DecodeID decodes the poller ID, returning [pollerID, kind] or an error. +func DecodeID(tk string) (string, string, error) { + raw := strings.Split(tk, idSeparator) + // strings.Split will include any/all whitespace strings, we want to omit those + parts := []string{} + for _, r := range raw { + if s := strings.TrimSpace(r); s != "" { + parts = append(parts, s) + } + } + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid token %s", tk) + } + return parts[0], parts[1], nil +} + +// used if the operation synchronously completed +type NopPoller struct{} + +func (*NopPoller) URL() string { + return "" +} + +func (*NopPoller) Done() bool { + return true +} + +func (*NopPoller) Update(*http.Response) error { + return nil +} + +func (*NopPoller) FinalGetURL() string { + return "" +} + +func (*NopPoller) Status() string { + return StatusSucceeded +} diff --git a/sdk/azcore/internal/pollers/util_test.go b/sdk/azcore/internal/pollers/util_test.go new file mode 100644 index 000000000000..04932bfcdf73 --- /dev/null +++ b/sdk/azcore/internal/pollers/util_test.go @@ -0,0 +1,169 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "net/http" + "strings" + "testing" +) + +func TestIsTerminalState(t *testing.T) { + if IsTerminalState("Updating") { + t.Fatal("Updating is not a terminal state") + } + if !IsTerminalState("Succeeded") { + t.Fatal("Succeeded is a terminal state") + } + if !IsTerminalState("failed") { + t.Fatal("failed is a terminal state") + } + if !IsTerminalState("canceled") { + t.Fatal("canceled is a terminal state") + } +} + +func TestStatusCodeValid(t *testing.T) { + if !StatusCodeValid(&http.Response{StatusCode: http.StatusOK}) { + t.Fatal("unexpected valid code") + } + if !StatusCodeValid(&http.Response{StatusCode: http.StatusAccepted}) { + t.Fatal("unexpected valid code") + } + if !StatusCodeValid(&http.Response{StatusCode: http.StatusCreated}) { + t.Fatal("unexpected valid code") + } + if !StatusCodeValid(&http.Response{StatusCode: http.StatusNoContent}) { + t.Fatal("unexpected valid code") + } + if StatusCodeValid(&http.Response{StatusCode: http.StatusPartialContent}) { + t.Fatal("unexpected valid code") + } + if StatusCodeValid(&http.Response{StatusCode: http.StatusBadRequest}) { + t.Fatal("unexpected valid code") + } + if StatusCodeValid(&http.Response{StatusCode: http.StatusInternalServerError}) { + t.Fatal("unexpected valid code") + } +} + +func TestMakeID(t *testing.T) { + const ( + pollerID = "pollerID" + kind = "kind" + ) + id := MakeID(pollerID, kind) + parts := strings.Split(id, idSeparator) + if l := len(parts); l != 2 { + t.Fatalf("unexpected length %d", l) + } + if p := parts[0]; p != pollerID { + t.Fatalf("unexpected poller ID %s", p) + } + if p := parts[1]; p != kind { + t.Fatalf("unexpected poller kind %s", p) + } +} + +func TestDecodeID(t *testing.T) { + _, _, err := DecodeID("") + if err == nil { + t.Fatal("unexpected nil error") + } + _, _, err = DecodeID("invalid_token") + if err == nil { + t.Fatal("unexpected nil error") + } + _, _, err = DecodeID("invalid_token;") + if err == nil { + t.Fatal("unexpected nil error") + } + _, _, err = DecodeID(" ;invalid_token") + if err == nil { + t.Fatal("unexpected nil error") + } + _, _, err = DecodeID("invalid;token;too") + if err == nil { + t.Fatal("unexpected nil error") + } + id, kind, err := DecodeID("pollerID;kind") + if err != nil { + t.Fatal(err) + } + if id != "pollerID" { + t.Fatalf("unexpected ID %s", id) + } + if kind != "kind" { + t.Fatalf("unexpected kin %s", kind) + } +} + +func TestIsValidURL(t *testing.T) { + if IsValidURL("/foo") { + t.Fatal("unexpected valid URL") + } + if !IsValidURL("https://foo.bar/baz") { + t.Fatal("expected valid URL") + } +} + +func TestFailed(t *testing.T) { + if Failed("Succeeded") || Failed("Updating") { + t.Fatal("unexpected failure") + } + if !Failed("failed") { + t.Fatal("expected failure") + } +} + +func TestNopPoller(t *testing.T) { + np := NopPoller{} + if !np.Done() { + t.Fatal("expected done") + } + if np.FinalGetURL() != "" { + t.Fatal("expected empty final get URL") + } + if np.Status() != StatusSucceeded { + t.Fatal("expected Succeeded") + } + if np.URL() != "" { + t.Fatal("expected empty URL") + } + if err := np.Update(nil); err != nil { + t.Fatal(err) + } +} + +/*func TestNewPollerNop(t *testing.T) { + srv, close := mock.NewServer() + defer close() + resp := initialResponse(http.MethodPost, srv.URL(), strings.NewReader(successResp)) + resp.StatusCode = http.StatusOK + poller, err := NewPoller("pollerID", "", resp, getPipeline(srv), handleError) + if err != nil { + t.Fatal(err) + } + if _, ok := poller.lro.(*nopPoller); !ok { + t.Fatalf("unexpected poller type %T", poller.lro) + } + tk, err := poller.ResumeToken() + if err == nil { + t.Fatal("unexpected nil error") + } + if tk != "" { + t.Fatal("expected empty token") + } + var result mockType + _, err = poller.PollUntilDone(context.Background(), 10*time.Millisecond, &result) + if err != nil { + t.Fatal(err) + } + if v := *result.Field; v != "value" { + t.Fatalf("unexpected value %s", v) + } +}*/ diff --git a/sdk/azcore/internal/shared/constants.go b/sdk/azcore/internal/shared/constants.go new file mode 100644 index 000000000000..d2862d8efa3f --- /dev/null +++ b/sdk/azcore/internal/shared/constants.go @@ -0,0 +1,34 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +const ( + ContentTypeAppJSON = "application/json" + ContentTypeAppXML = "application/xml" +) + +const ( + HeaderAzureAsync = "Azure-AsyncOperation" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderLocation = "Location" + HeaderOperationLocation = "Operation-Location" + HeaderRetryAfter = "Retry-After" + HeaderUserAgent = "User-Agent" +) + +const ( + DefaultMaxRetries = 3 +) + +const ( + // Module is the name of the calling module used in telemetry data. + Module = "azcore" + + // Version is the semantic version (see http://semver.org) of this module. + Version = "v0.19.0" +) diff --git a/sdk/azcore/internal/shared/shared.go b/sdk/azcore/internal/shared/shared.go new file mode 100644 index 000000000000..c275221f8886 --- /dev/null +++ b/sdk/azcore/internal/shared/shared.go @@ -0,0 +1,148 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "io/ioutil" + "net/http" + "strconv" + "time" +) + +// CtxWithHTTPHeaderKey is used as a context key for adding/retrieving http.Header. +type CtxWithHTTPHeaderKey struct{} + +// CtxWithRetryOptionsKey is used as a context key for adding/retrieving RetryOptions. +type CtxWithRetryOptionsKey struct{} + +type nopCloser struct { + io.ReadSeeker +} + +func (n nopCloser) Close() error { + return nil +} + +// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. +func NopCloser(rs io.ReadSeeker) io.ReadSeekCloser { + return nopCloser{rs} +} + +// BodyDownloadPolicyOpValues is the struct containing the per-operation values +type BodyDownloadPolicyOpValues struct { + Skip bool +} + +func NewResponseError(inner error, resp *http.Response) error { + return &ResponseError{inner: inner, resp: resp} +} + +type ResponseError struct { + inner error + resp *http.Response +} + +// Error implements the error interface for type ResponseError. +func (e *ResponseError) Error() string { + return e.inner.Error() +} + +// Unwrap returns the inner error. +func (e *ResponseError) Unwrap() error { + return e.inner +} + +// RawResponse returns the HTTP response associated with this error. +func (e *ResponseError) RawResponse() *http.Response { + return e.resp +} + +// NonRetriable indicates this error is non-transient. +func (e *ResponseError) NonRetriable() { + // marker method +} + +// Delay waits for the duration to elapse or the context to be cancelled. +func Delay(ctx context.Context, delay time.Duration) error { + select { + case <-time.After(delay): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// ErrNoBody is returned if the response didn't contain a body. +var ErrNoBody = errors.New("the response did not contain a body") + +// GetJSON reads the response body into a raw JSON object. +// It returns ErrNoBody if there was no content. +func GetJSON(resp *http.Response) (map[string]interface{}, error) { + body, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + return nil, err + } + if len(body) == 0 { + return nil, ErrNoBody + } + // put the body back so it's available to others + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + // unmarshall the body to get the value + var jsonBody map[string]interface{} + if err = json.Unmarshal(body, &jsonBody); err != nil { + return nil, err + } + return jsonBody, nil +} + +// RetryAfter returns non-zero if the response contains a Retry-After header value. +func RetryAfter(resp *http.Response) time.Duration { + if resp == nil { + return 0 + } + ra := resp.Header.Get(HeaderRetryAfter) + if ra == "" { + return 0 + } + // retry-after values are expressed in either number of + // seconds or an HTTP-date indicating when to try again + if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 { + return time.Duration(retryAfter) * time.Second + } else if t, err := time.Parse(time.RFC1123, ra); err == nil { + return time.Until(t) + } + return 0 +} + +// HasStatusCode returns true if the Response's status code is one of the specified values. +func HasStatusCode(resp *http.Response, statusCodes ...int) bool { + if resp == nil { + return false + } + for _, sc := range statusCodes { + if resp.StatusCode == sc { + return true + } + } + return false +} + +const defaultScope = "/.default" + +// EndpointToScope converts the provided URL endpoint to its default scope. +func EndpointToScope(endpoint string) string { + if endpoint[len(endpoint)-1] != '/' { + endpoint += "/" + } + return endpoint + defaultScope +} diff --git a/sdk/azcore/internal/shared/shared_test.go b/sdk/azcore/internal/shared/shared_test.go new file mode 100644 index 000000000000..225b89cc6bc4 --- /dev/null +++ b/sdk/azcore/internal/shared/shared_test.go @@ -0,0 +1,133 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +import ( + "context" + "errors" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" +) + +func TestNopCloser(t *testing.T) { + nc := NopCloser(strings.NewReader("foo")) + if err := nc.Close(); err != nil { + t.Fatal(err) + } +} + +type testError struct { + m string +} + +func (t testError) Error() string { + return t.m +} + +func TestNewResponseError(t *testing.T) { + err := NewResponseError(testError{m: "crash"}, &http.Response{StatusCode: http.StatusInternalServerError}) + if s := err.Error(); s != "crash" { + t.Fatalf("unexpected error %s", s) + } + re, ok := err.(*ResponseError) + if !ok { + t.Fatalf("unexpected error type %T", err) + } + re.NonRetriable() // nop + if c := re.RawResponse().StatusCode; c != http.StatusInternalServerError { + t.Fatalf("unexpected status code %d", c) + } + var te testError + if !errors.As(err, &te) { + t.Fatal("unwrap failed") + } +} + +func TestDelay(t *testing.T) { + if err := Delay(context.Background(), 5*time.Millisecond); err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := Delay(ctx, 5*time.Minute); err == nil { + t.Fatal("unexpected nil error") + } +} + +func TestGetJSON(t *testing.T) { + j, err := GetJSON(&http.Response{Body: http.NoBody}) + if !errors.Is(err, ErrNoBody) { + t.Fatal(err) + } + if j != nil { + t.Fatal("expected nil json") + } + j, err = GetJSON(&http.Response{Body: ioutil.NopCloser(strings.NewReader(`{ "foo": "bar" }`))}) + if err != nil { + t.Fatal(err) + } + if v := j["foo"]; v != "bar" { + t.Fatalf("unexpected value %s", v) + } +} + +func TestRetryAfter(t *testing.T) { + if RetryAfter(nil) != 0 { + t.Fatal("expected zero duration") + } + resp := &http.Response{ + Header: http.Header{}, + } + if d := RetryAfter(resp); d > 0 { + t.Fatalf("unexpected retry-after value %d", d) + } + resp.Header.Set(HeaderRetryAfter, "300") + d := RetryAfter(resp) + if d <= 0 { + t.Fatal("expected retry-after value from seconds") + } + if d != 300*time.Second { + t.Fatalf("expected 300 seconds, got %d", d/time.Second) + } + atDate := time.Now().Add(600 * time.Second) + resp.Header.Set(HeaderRetryAfter, atDate.Format(time.RFC1123)) + d = RetryAfter(resp) + if d <= 0 { + t.Fatal("expected retry-after value from date") + } + // d will not be exactly 600 seconds but it will be close + if s := d / time.Second; s < 598 || s > 602 { + t.Fatalf("expected ~600 seconds, got %d", s) + } +} + +func TestHasStatusCode(t *testing.T) { + if HasStatusCode(nil, http.StatusAccepted) { + t.Fatal("unexpected success") + } + if HasStatusCode(&http.Response{}) { + t.Fatal("unexpected success") + } + if HasStatusCode(&http.Response{StatusCode: http.StatusBadGateway}, http.StatusBadRequest) { + t.Fatal("unexpected success") + } + if !HasStatusCode(&http.Response{StatusCode: http.StatusOK}, http.StatusAccepted, http.StatusOK, http.StatusNoContent) { + t.Fatal("unexpected failure") + } +} + +func TestEndpointToScope(t *testing.T) { + if s := EndpointToScope("https://management.microsoftazure.de/"); s != "https://management.microsoftazure.de//.default" { + t.Fatalf("unexpected scope %s", s) + } + if s := EndpointToScope("https://management.usgovcloudapi.net"); s != "https://management.usgovcloudapi.net//.default" { + t.Fatalf("unexpected scope %s", s) + } +} diff --git a/sdk/azcore/policy/policy.go b/sdk/azcore/policy/policy.go new file mode 100644 index 000000000000..77722339f44f --- /dev/null +++ b/sdk/azcore/policy/policy.go @@ -0,0 +1,96 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package policy + +import ( + "context" + "net/http" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +// Policy represents an extensibility point for the Pipeline that can mutate the specified +// Request and react to the received Response. +type Policy = pipeline.Policy + +// Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. +type Transporter = pipeline.Transporter + +// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. +// Don't use this type directly, use runtime.NewRequest() instead. +type Request = pipeline.Request + +// LogOptions configures the logging policy's behavior. +type LogOptions struct { + // IncludeBody indicates if request and response bodies should be included in logging. + // The default value is false. + // NOTE: enabling this can lead to disclosure of sensitive information, use with care. + IncludeBody bool +} + +// RetryOptions configures the retry policy's behavior. +// Call NewRetryOptions() to create an instance with default values. +type RetryOptions struct { + // MaxRetries specifies the maximum number of attempts a failed operation will be retried + // before producing an error. + // The default value is three. A value less than zero means one try and no retries. + MaxRetries int32 + + // TryTimeout indicates the maximum time allowed for any single try of an HTTP request. + // This is disabled by default. Specify a value greater than zero to enable. + // NOTE: Setting this to a small value might cause premature HTTP request time-outs. + TryTimeout time.Duration + + // RetryDelay specifies the initial amount of delay to use before retrying an operation. + // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay. + // The default value is four seconds. A value less than zero means no delay between retries. + RetryDelay time.Duration + + // MaxRetryDelay specifies the maximum delay allowed before retrying an operation. + // Typically the value is greater than or equal to the value specified in RetryDelay. + // The default Value is 120 seconds. A value less than zero means there is no cap. + MaxRetryDelay time.Duration + + // StatusCodes specifies the HTTP status codes that indicate the operation should be retried. + // The default value is the status codes in StatusCodesForRetry. + // Specifying an empty slice will cause retries to happen only for transport errors. + StatusCodes []int +} + +// TelemetryOptions configures the telemetry policy's behavior. +type TelemetryOptions struct { + // ApplicationID is an application-specific identification string used in telemetry. + // It has a maximum length of 24 characters and must not contain any spaces. + ApplicationID string + + // Disabled will prevent the addition of any telemetry data to the User-Agent. + Disabled bool +} + +// TokenRequestOptions contain specific parameter that may be used by credentials types when attempting to get a token. +type TokenRequestOptions struct { + // Scopes contains the list of permission scopes required for the token. + Scopes []string + // TenantID contains the tenant ID to use in a multi-tenant authentication scenario, if TenantID is set + // it will override the tenant ID that was added at credential creation time. + TenantID string +} + +// WithHTTPHeader adds the specified http.Header to the parent context. +// Use this to specify custom HTTP headers at the API-call level. +// Any overlapping headers will have their values replaced with the values specified here. +func WithHTTPHeader(parent context.Context, header http.Header) context.Context { + return context.WithValue(parent, shared.CtxWithHTTPHeaderKey{}, header) +} + +// WithRetryOptions adds the specified RetryOptions to the parent context. +// Use this to specify custom RetryOptions at the API-call level. +func WithRetryOptions(parent context.Context, options RetryOptions) context.Context { + return context.WithValue(parent, shared.CtxWithRetryOptionsKey{}, options) +} diff --git a/sdk/azcore/policy/policy_test.go b/sdk/azcore/policy/policy_test.go new file mode 100644 index 000000000000..65bcf125cbe4 --- /dev/null +++ b/sdk/azcore/policy/policy_test.go @@ -0,0 +1,54 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package policy + +import ( + "context" + "math" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +func TestWithHTTPHeader(t *testing.T) { + const ( + key = "some" + val = "thing" + ) + input := http.Header{} + input.Set(key, val) + ctx := WithHTTPHeader(context.Background(), input) + if ctx == nil { + t.Fatal("nil context") + } + raw := ctx.Value(shared.CtxWithHTTPHeaderKey{}) + header, ok := raw.(http.Header) + if !ok { + t.Fatalf("unexpected type %T", raw) + } + if v := header.Get(key); v != val { + t.Fatalf("unexpected value %s", v) + } +} + +func TestWithRetryOptions(t *testing.T) { + ctx := WithRetryOptions(context.Background(), RetryOptions{ + MaxRetries: math.MaxInt32, + }) + if ctx == nil { + t.Fatal("nil context") + } + raw := ctx.Value(shared.CtxWithRetryOptionsKey{}) + opts, ok := raw.(RetryOptions) + if !ok { + t.Fatalf("unexpected type %T", raw) + } + if opts.MaxRetries != math.MaxInt32 { + t.Fatalf("unexpected value %d", opts.MaxRetries) + } +} diff --git a/sdk/azcore/policy_anonymous_credential.go b/sdk/azcore/policy_anonymous_credential.go index 496b976f561c..45f3e993b118 100644 --- a/sdk/azcore/policy_anonymous_credential.go +++ b/sdk/azcore/policy_anonymous_credential.go @@ -6,13 +6,19 @@ package azcore -import "net/http" +import ( + "net/http" -func anonCredAuthPolicyFunc(AuthenticationOptions) Policy { - return policyFunc(anonCredPolicyFunc) + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) + +func anonCredAuthPolicyFunc(runtime.AuthenticationOptions) policy.Policy { + return pipeline.PolicyFunc(anonCredPolicyFunc) } -func anonCredPolicyFunc(req *Request) (*http.Response, error) { +func anonCredPolicyFunc(req *policy.Request) (*http.Response, error) { return req.Next() } diff --git a/sdk/azcore/policy_anonymous_credential_test.go b/sdk/azcore/policy_anonymous_credential_test.go index e625470dbcf7..3f6d88a743d5 100644 --- a/sdk/azcore/policy_anonymous_credential_test.go +++ b/sdk/azcore/policy_anonymous_credential_test.go @@ -12,6 +12,7 @@ import ( "reflect" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -19,8 +20,8 @@ func TestAnonymousCredential(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithStatusCode(http.StatusOK)) - pl := NewPipeline(srv, NewAnonymousCredential().NewAuthenticationPolicy(AuthenticationOptions{})) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + pl := runtime.NewPipeline(srv, NewAnonymousCredential().NewAuthenticationPolicy(runtime.AuthenticationOptions{})) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -28,7 +29,7 @@ func TestAnonymousCredential(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !reflect.DeepEqual(req.Header, resp.Request.Header) { + if !reflect.DeepEqual(req.Raw().Header, resp.Request.Header) { t.Fatal("unexpected modification to request headers") } } diff --git a/sdk/azcore/policy_http_header.go b/sdk/azcore/policy_http_header.go deleted file mode 100644 index ba7650280d9f..000000000000 --- a/sdk/azcore/policy_http_header.go +++ /dev/null @@ -1,39 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "context" - "net/http" -) - -// used as a context key for adding/retrieving http.Header -type ctxWithHTTPHeader struct{} - -// newHTTPHeaderPolicy creates a policy object that adds custom HTTP headers to a request -func httpHeaderPolicy(req *Request) (*http.Response, error) { - // check if any custom HTTP headers have been specified - if header := req.Context().Value(ctxWithHTTPHeader{}); header != nil { - for k, v := range header.(http.Header) { - // use Set to replace any existing value - // it also canonicalizes the header key - req.Header.Set(k, v[0]) - // add any remaining values - for i := 1; i < len(v); i++ { - req.Header.Add(k, v[i]) - } - } - } - return req.Next() -} - -// WithHTTPHeader adds the specified http.Header to the parent context. -// Use this to specify custom HTTP headers at the API-call level. -// Any overlapping headers will have their values replaced with the values specified here. -func WithHTTPHeader(parent context.Context, header http.Header) context.Context { - return context.WithValue(parent, ctxWithHTTPHeader{}, header) -} diff --git a/sdk/azcore/poller.go b/sdk/azcore/poller.go deleted file mode 100644 index 340cb0859084..000000000000 --- a/sdk/azcore/poller.go +++ /dev/null @@ -1,448 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "net/http" - "strconv" - "strings" - "time" - - "github.com/Azure/azure-sdk-for-go/sdk/internal/log" -) - -// NewPoller creates a Poller based on the provided initial response. -// pollerID - a unique identifier for an LRO, it's usually the client.Method string. -// NOTE: this is only meant for internal use in generated code. -func NewPoller(pollerID string, resp *http.Response, pl Pipeline, eu func(*http.Response) error) (*Poller, error) { - // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). - // ideally the codegen should return an error if the initial response failed and not even create a poller. - if !lroStatusCodeValid(resp) { - return nil, errors.New("the operation failed or was cancelled") - } - opLoc := resp.Header.Get(headerOperationLocation) - loc := resp.Header.Get(headerLocation) - // in the case of both headers, always prefer the operation-location header - if opLoc != "" { - return &Poller{ - lro: newOpPoller(pollerID, opLoc, loc, resp), - pl: pl, - eu: eu, - resp: resp, - }, nil - } - if loc != "" { - return &Poller{ - lro: newLocPoller(pollerID, loc, resp.StatusCode), - pl: pl, - eu: eu, - resp: resp, - }, nil - } - return &Poller{lro: &nopPoller{}, resp: resp}, nil -} - -// NewPollerFromResumeToken creates a Poller from a resume token string. -// pollerID - a unique identifier for an LRO, it's usually the client.Method string. -// NOTE: this is only meant for internal use in generated code. -func NewPollerFromResumeToken(pollerID string, token string, pl Pipeline, eu func(*http.Response) error) (*Poller, error) { - // unmarshal into JSON object to determine the poller type - obj := map[string]interface{}{} - err := json.Unmarshal([]byte(token), &obj) - if err != nil { - return nil, err - } - t, ok := obj["type"] - if !ok { - return nil, errors.New("missing type field") - } - tt, ok := t.(string) - if !ok { - return nil, fmt.Errorf("invalid type format %T", t) - } - // the type is encoded as "pollerType;lroPoller" - sem := strings.LastIndex(tt, ";") - if sem < 0 { - return nil, fmt.Errorf("invalid poller type %s", tt) - } - // ensure poller types match - if received := tt[:sem]; received != pollerID { - return nil, fmt.Errorf("cannot resume from this poller token. expected %s, received %s", pollerID, received) - } - // now rehydrate the poller based on the encoded poller type - var lro lroPoller - switch pt := tt[sem+1:]; pt { - case "opPoller": - lro = &opPoller{} - case "locPoller": - lro = &locPoller{} - default: - return nil, fmt.Errorf("unhandled lroPoller type %s", pt) - } - if err = json.Unmarshal([]byte(token), lro); err != nil { - return nil, err - } - return &Poller{lro: lro, pl: pl, eu: eu}, nil -} - -// Poller encapsulates state and logic for polling on long-running operations. -// NOTE: this is only meant for internal use in generated code. -type Poller struct { - lro lroPoller - pl Pipeline - eu func(*http.Response) error - resp *http.Response - err error -} - -// Done returns true if the LRO has reached a terminal state. -func (l *Poller) Done() bool { - if l.err != nil { - return true - } - return l.lro.Done() -} - -// Poll sends a polling request to the polling endpoint and returns the response or error. -func (l *Poller) Poll(ctx context.Context) (*http.Response, error) { - if l.Done() { - // the LRO has reached a terminal state, don't poll again - if l.resp != nil { - return l.resp, nil - } - return nil, l.err - } - req, err := NewRequest(ctx, http.MethodGet, l.lro.URL()) - if err != nil { - return nil, err - } - resp, err := l.pl.Do(req) - if err != nil { - // don't update the poller for failed requests - return nil, err - } - if !lroStatusCodeValid(resp) { - // the LRO failed. unmarshall the error and update state - l.err = l.eu(resp) - l.resp = nil - return nil, l.err - } - if err = l.lro.Update(resp); err != nil { - return nil, err - } - l.resp = resp - return l.resp, nil -} - -// ResumeToken returns a token string that can be used to resume a poller that has not yet reached a terminal state. -func (l *Poller) ResumeToken() (string, error) { - if l.Done() { - return "", errors.New("cannot create a ResumeToken from a poller in a terminal state") - } - b, err := json.Marshal(l.lro) - if err != nil { - return "", err - } - return string(b), nil -} - -// FinalResponse will perform a final GET request and return the final HTTP response for the polling -// operation and unmarshall the content of the payload into the respType interface that is provided. -func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http.Response, error) { - if !l.Done() { - return nil, errors.New("cannot return a final response from a poller in a non-terminal state") - } - // if there's nothing to unmarshall into just return the final response - if respType == nil { - return l.resp, nil - } - u, err := l.lro.FinalGetURL(l.resp) - if err != nil { - return nil, err - } - if u != "" { - req, err := NewRequest(ctx, http.MethodGet, u) - if err != nil { - return nil, err - } - resp, err := l.pl.Do(req) - if err != nil { - return nil, err - } - if !lroStatusCodeValid(resp) { - return nil, l.eu(resp) - } - l.resp = resp - } - body, err := ioutil.ReadAll(l.resp.Body) - l.resp.Body.Close() - if err != nil { - return nil, err - } - if err = json.Unmarshal(body, respType); err != nil { - return nil, err - } - return l.resp, nil -} - -// PollUntilDone will handle the entire span of the polling operation until a terminal state is reached, -// then return the final HTTP response for the polling operation and unmarshal the content of the payload -// into the respType interface that is provided. -func (l *Poller) PollUntilDone(ctx context.Context, freq time.Duration, respType interface{}) (*http.Response, error) { - logPollUntilDoneExit := func(v interface{}) { - log.Writef(log.LongRunningOperation, "END PollUntilDone() for %T: %v", l.lro, v) - } - log.Writef(log.LongRunningOperation, "BEGIN PollUntilDone() for %T", l.lro) - if l.resp != nil { - // initial check for a retry-after header existing on the initial response - if retryAfter := RetryAfter(l.resp); retryAfter > 0 { - log.Writef(log.LongRunningOperation, "initial Retry-After delay for %s", retryAfter.String()) - if err := delay(ctx, retryAfter); err != nil { - logPollUntilDoneExit(err) - return nil, err - } - } - } - // begin polling the endpoint until a terminal state is reached - for { - resp, err := l.Poll(ctx) - if err != nil { - logPollUntilDoneExit(err) - return nil, err - } - if l.Done() { - logPollUntilDoneExit(l.lro.Status()) - if !l.lro.Succeeded() { - return nil, l.eu(resp) - } - return l.FinalResponse(ctx, respType) - } - d := freq - if retryAfter := RetryAfter(resp); retryAfter > 0 { - log.Writef(log.LongRunningOperation, "Retry-After delay for %s", retryAfter.String()) - d = retryAfter - } else { - log.Writef(log.LongRunningOperation, "delay for %s", d.String()) - } - if err = delay(ctx, d); err != nil { - logPollUntilDoneExit(err) - return nil, err - } - } -} - -// abstracts the differences between concrete poller types -type lroPoller interface { - Done() bool - Update(resp *http.Response) error - FinalGetURL(resp *http.Response) (string, error) - URL() string - Status() string - Succeeded() bool -} - -// ==================================================================================================== - -// polls on the operation-location header -type opPoller struct { - Type string `json:"type"` - ReqMethod string `json:"reqMethod"` - ReqURL string `json:"reqURL"` - PollURL string `json:"pollURL"` - LocURL string `json:"locURL"` - status string -} - -func newOpPoller(pollerType, pollingURL, locationURL string, initialResponse *http.Response) *opPoller { - return &opPoller{ - Type: fmt.Sprintf("%s;opPoller", pollerType), - ReqMethod: initialResponse.Request.Method, - ReqURL: initialResponse.Request.URL.String(), - PollURL: pollingURL, - LocURL: locationURL, - } -} - -func (p *opPoller) URL() string { - return p.PollURL -} - -func (p *opPoller) Done() bool { - return strings.EqualFold(p.status, "succeeded") || - strings.EqualFold(p.status, "failed") || - strings.EqualFold(p.status, "cancelled") -} - -func (p *opPoller) Succeeded() bool { - return strings.EqualFold(p.status, "succeeded") -} - -func (p *opPoller) Update(resp *http.Response) error { - status, err := extractJSONValue(resp, "status") - if err != nil { - return err - } - if status == "" { - return errors.New("no status found in body") - } - p.status = status - // if the endpoint returned an operation-location header, update cached value - if opLoc := resp.Header.Get(headerOperationLocation); opLoc != "" { - p.PollURL = opLoc - } - return nil -} - -func (p *opPoller) FinalGetURL(resp *http.Response) (string, error) { - if !p.Done() { - return "", errors.New("cannot return a final response from a poller in a non-terminal state") - } - resLoc, err := extractJSONValue(resp, "resourceLocation") - if err != nil { - return "", err - } - if resLoc != "" { - return resLoc, nil - } - if p.ReqMethod == http.MethodPatch || p.ReqMethod == http.MethodPut { - return p.ReqURL, nil - } - if p.ReqMethod == http.MethodPost && p.LocURL != "" { - return p.LocURL, nil - } - return "", nil -} - -func (p *opPoller) Status() string { - return p.status -} - -// ==================================================================================================== - -// polls on the location header -type locPoller struct { - Type string `json:"type"` - PollURL string `json:"pollURL"` - status int -} - -func newLocPoller(pollerType, pollingURL string, initialStatus int) *locPoller { - return &locPoller{ - Type: fmt.Sprintf("%s;locPoller", pollerType), - PollURL: pollingURL, - status: initialStatus, - } -} - -func (p *locPoller) URL() string { - return p.PollURL -} - -func (p *locPoller) Done() bool { - // a 202 means the operation is still in progress - // zero-value indicates the poller was rehydrated from a token - return p.status > 0 && p.status != http.StatusAccepted -} - -func (p *locPoller) Succeeded() bool { - // any 2xx status code indicates success - return p.status >= 200 && p.status < 300 -} - -func (p *locPoller) Update(resp *http.Response) error { - // if the endpoint returned a location header, update cached value - if loc := resp.Header.Get(headerLocation); loc != "" { - p.PollURL = loc - } - p.status = resp.StatusCode - return nil -} - -func (*locPoller) FinalGetURL(*http.Response) (string, error) { - return "", nil -} - -func (p *locPoller) Status() string { - return strconv.Itoa(p.status) -} - -// ==================================================================================================== - -// used if the endpoint didn't return any polling headers (synchronous completion) -type nopPoller struct{} - -func (*nopPoller) URL() string { - return "" -} - -func (*nopPoller) Done() bool { - return true -} - -func (*nopPoller) Succeeded() bool { - return true -} - -func (*nopPoller) Update(*http.Response) error { - return nil -} - -func (*nopPoller) FinalGetURL(*http.Response) (string, error) { - return "", nil -} - -func (*nopPoller) Status() string { - return "succeeded" -} - -// returns true if the LRO response contains a valid HTTP status code -func lroStatusCodeValid(resp *http.Response) bool { - return HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) -} - -// extracs a JSON value from the provided reader -func extractJSONValue(resp *http.Response, val string) (string, error) { - if resp.ContentLength == 0 { - return "", errors.New("the response does not contain a body") - } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", err - } - // put the body back so it's available to our callers - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) - // unmarshall the body to get the value - var jsonBody map[string]interface{} - if err = json.Unmarshal(body, &jsonBody); err != nil { - return "", err - } - v, ok := jsonBody[val] - if !ok { - // it might be ok if the field doesn't exist, the caller must make that determination - return "", nil - } - vv, ok := v.(string) - if !ok { - return "", fmt.Errorf("the %s value %v was not in string format", val, v) - } - return vv, nil -} - -func delay(ctx context.Context, delay time.Duration) error { - select { - case <-time.After(delay): - return nil - case <-ctx.Done(): - return ctx.Err() - } -} diff --git a/sdk/azcore/request.go b/sdk/azcore/request.go deleted file mode 100644 index a3074596a2f8..000000000000 --- a/sdk/azcore/request.go +++ /dev/null @@ -1,398 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "encoding/xml" - "errors" - "fmt" - "io" - "io/ioutil" - "mime/multipart" - "net/http" - "reflect" - "strconv" - "strings" - - "golang.org/x/net/http/httpguts" -) - -const ( - contentTypeAppJSON = "application/json" - contentTypeAppXML = "application/xml" -) - -// Base64Encoding is usesd to specify which base-64 encoder/decoder to use when -// encoding/decoding a slice of bytes to/from a string. -type Base64Encoding int - -const ( - // Base64StdFormat uses base64.StdEncoding for encoding and decoding payloads. - Base64StdFormat Base64Encoding = 0 - - // Base64URLFormat uses base64.RawURLEncoding for encoding and decoding payloads. - Base64URLFormat Base64Encoding = 1 -) - -// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. -// Don't use this type directly, use NewRequest() instead. -type Request struct { - *http.Request - body ReadSeekCloser - policies []Policy - values opValues -} - -type opValues map[reflect.Type]interface{} - -// Set adds/changes a value -func (ov opValues) set(value interface{}) { - ov[reflect.TypeOf(value)] = value -} - -// Get looks for a value set by SetValue first -func (ov opValues) get(value interface{}) bool { - v, ok := ov[reflect.ValueOf(value).Elem().Type()] - if ok { - reflect.ValueOf(value).Elem().Set(reflect.ValueOf(v)) - } - return ok -} - -// JoinPaths concatenates multiple URL path segments into one path, -// inserting path separation characters as required. JoinPaths will preserve -// query parameters in the root path -func JoinPaths(root string, paths ...string) string { - if len(paths) == 0 { - return root - } - - qps := "" - if strings.Contains(root, "?") { - splitPath := strings.Split(root, "?") - root, qps = splitPath[0], splitPath[1] - } - - for i := 0; i < len(paths); i++ { - root = strings.TrimRight(root, "/") - paths[i] = strings.TrimLeft(paths[i], "/") - root += "/" + paths[i] - } - - if qps != "" { - if !strings.HasSuffix(root, "/") { - root += "/" - } - return root + "?" + qps - } - return root -} - -// NewRequest creates a new Request with the specified input. -func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) { - req, err := http.NewRequestWithContext(ctx, httpMethod, endpoint, nil) - if err != nil { - return nil, err - } - if req.URL.Host == "" { - return nil, errors.New("no Host in request URL") - } - if !(req.URL.Scheme == "http" || req.URL.Scheme == "https") { - return nil, fmt.Errorf("unsupported protocol scheme %s", req.URL.Scheme) - } - return &Request{Request: req}, nil -} - -// Next calls the next policy in the pipeline. -// If there are no more policies, nil and ErrNoMorePolicies are returned. -// This method is intended to be called from pipeline policies. -// To send a request through a pipeline call Pipeline.Do(). -func (req *Request) Next() (*http.Response, error) { - if len(req.policies) == 0 { - return nil, errors.New("no more policies") - } - nextPolicy := req.policies[0] - nextReq := *req - nextReq.policies = nextReq.policies[1:] - return nextPolicy.Do(&nextReq) -} - -// MarshalAsByteArray will base-64 encode the byte slice v, then calls SetBody. -// The encoded value is treated as a JSON string. -func (req *Request) MarshalAsByteArray(v []byte, format Base64Encoding) error { - // send as a JSON string - encode := fmt.Sprintf("\"%s\"", EncodeByteArray(v, format)) - return req.SetBody(NopCloser(strings.NewReader(encode)), contentTypeAppJSON) -} - -// MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody. -func (req *Request) MarshalAsJSON(v interface{}) error { - v = cloneWithoutReadOnlyFields(v) - b, err := json.Marshal(v) - if err != nil { - return fmt.Errorf("error marshalling type %T: %s", v, err) - } - return req.SetBody(NopCloser(bytes.NewReader(b)), contentTypeAppJSON) -} - -// MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody. -func (req *Request) MarshalAsXML(v interface{}) error { - b, err := xml.Marshal(v) - if err != nil { - return fmt.Errorf("error marshalling type %T: %s", v, err) - } - return req.SetBody(NopCloser(bytes.NewReader(b)), contentTypeAppXML) -} - -// SetOperationValue adds/changes a mutable key/value associated with a single operation. -func (req *Request) SetOperationValue(value interface{}) { - if req.values == nil { - req.values = opValues{} - } - req.values.set(value) -} - -// OperationValue looks for a value set by SetOperationValue(). -func (req *Request) OperationValue(value interface{}) bool { - if req.values == nil { - return false - } - return req.values.get(value) -} - -// SetBody sets the specified ReadSeekCloser as the HTTP request body. -func (req *Request) SetBody(body ReadSeekCloser, contentType string) error { - // Set the body and content length. - size, err := body.Seek(0, io.SeekEnd) // Seek to the end to get the stream's size - if err != nil { - return err - } - if size == 0 { - body.Close() - return nil - } - _, err = body.Seek(0, io.SeekStart) - if err != nil { - return err - } - // keep a copy of the original body. this is to handle cases - // where req.Body is replaced, e.g. httputil.DumpRequest and friends. - req.body = body - req.Request.Body = body - req.Request.ContentLength = size - req.Header.Set(headerContentType, contentType) - req.Header.Set(headerContentLength, strconv.FormatInt(size, 10)) - return nil -} - -// SetMultipartFormData writes the specified keys/values as multi-part form -// fields with the specified value. File content must be specified as a ReadSeekCloser. -// All other values are treated as string values. -func (req *Request) SetMultipartFormData(formData map[string]interface{}) error { - body := bytes.Buffer{} - writer := multipart.NewWriter(&body) - for k, v := range formData { - if rsc, ok := v.(ReadSeekCloser); ok { - // this is the body to upload, the key is its file name - fd, err := writer.CreateFormFile(k, k) - if err != nil { - return err - } - // copy the data to the form file - if _, err = io.Copy(fd, rsc); err != nil { - return err - } - continue - } - // ensure the value is in string format - s, ok := v.(string) - if !ok { - s = fmt.Sprintf("%v", v) - } - if err := writer.WriteField(k, s); err != nil { - return err - } - } - if err := writer.Close(); err != nil { - return err - } - req.body = NopCloser(bytes.NewReader(body.Bytes())) - req.Body = req.body - req.ContentLength = int64(body.Len()) - req.Header.Set(headerContentType, writer.FormDataContentType()) - req.Header.Set(headerContentLength, strconv.FormatInt(req.ContentLength, 10)) - return nil -} - -// SkipBodyDownload will disable automatic downloading of the response body. -func (req *Request) SkipBodyDownload() { - req.SetOperationValue(bodyDownloadPolicyOpValues{skip: true}) -} - -// RewindBody seeks the request's Body stream back to the beginning so it can be resent when retrying an operation. -func (req *Request) RewindBody() error { - if req.body != nil { - // Reset the stream back to the beginning and restore the body - _, err := req.body.Seek(0, io.SeekStart) - req.Body = req.body - return err - } - return nil -} - -// Close closes the request body. -func (req *Request) Close() error { - if req.Body == nil { - return nil - } - return req.Body.Close() -} - -// Telemetry adds telemetry data to the request. -// If telemetry reporting is disabled the value is discarded. -func (req *Request) Telemetry(v string) { - req.SetOperationValue(requestTelemetry(v)) -} - -// clone returns a deep copy of the request with its context changed to ctx -func (req *Request) clone(ctx context.Context) *Request { - r2 := Request{} - r2 = *req - r2.Request = req.Request.Clone(ctx) - return &r2 -} - -// valid returns nil if the underlying http.Request is well-formed. -func (req *Request) valid() error { - // check copied from Transport.roundTrip() - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - req.Close() - return fmt.Errorf("invalid header field name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - req.Close() - return fmt.Errorf("invalid header field value %q for key %v", v, k) - } - } - } - return nil -} - -// writes to a buffer, used for logging purposes -func (req *Request) writeBody(b *bytes.Buffer) error { - if req.Body == nil { - fmt.Fprint(b, " Request contained no body\n") - return nil - } - if ct := req.Header.Get(headerContentType); !shouldLogBody(b, ct) { - return nil - } - body, err := ioutil.ReadAll(req.Body) - if err != nil { - fmt.Fprintf(b, " Failed to read request body: %s\n", err.Error()) - return err - } - if err := req.RewindBody(); err != nil { - return err - } - logBody(b, body) - return nil -} - -// EncodeByteArray will base-64 encode the byte slice v. -func EncodeByteArray(v []byte, format Base64Encoding) string { - if format == Base64URLFormat { - return base64.RawURLEncoding.EncodeToString(v) - } - return base64.StdEncoding.EncodeToString(v) -} - -// returns a clone of the object graph pointed to by v, omitting values of all read-only -// fields. if there are no read-only fields in the object graph, no clone is created. -func cloneWithoutReadOnlyFields(v interface{}) interface{} { - val := reflect.Indirect(reflect.ValueOf(v)) - if val.Kind() != reflect.Struct { - // not a struct, skip - return v - } - // first walk the graph to find any R/O fields. - // if there aren't any, skip cloning the graph. - if !recursiveFindReadOnlyField(val) { - return v - } - return recursiveCloneWithoutReadOnlyFields(val) -} - -// returns true if any field in the object graph of val contains the `azure:"ro"` tag value -func recursiveFindReadOnlyField(val reflect.Value) bool { - t := val.Type() - // iterate over the fields, looking for the "azure" tag. - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - aztag := field.Tag.Get("azure") - if azureTagIsReadOnly(aztag) { - return true - } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct && recursiveFindReadOnlyField(reflect.Indirect(val.Field(i))) { - return true - } - } - return false -} - -// clones the object graph of val. all non-R/O properties are copied to the clone -func recursiveCloneWithoutReadOnlyFields(val reflect.Value) interface{} { - clone := reflect.New(val.Type()) - t := val.Type() - // iterate over the fields, looking for the "azure" tag. - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - aztag := field.Tag.Get("azure") - if azureTagIsReadOnly(aztag) { - // omit from payload - } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct { - // recursive case - v := recursiveCloneWithoutReadOnlyFields(reflect.Indirect(val.Field(i))) - if t.Field(i).Anonymous { - // NOTE: this does not handle the case of embedded fields of unexported struct types. - // this should be ok as we don't generate any code like this at present - reflect.Indirect(clone).Field(i).Set(reflect.Indirect(reflect.ValueOf(v))) - } else { - reflect.Indirect(clone).Field(i).Set(reflect.ValueOf(v)) - } - } else { - // no azure RO tag, non-recursive case, include in payload - reflect.Indirect(clone).Field(i).Set(val.Field(i)) - } - } - return clone.Interface() -} - -// returns true if the "azure" tag contains the option "ro" -func azureTagIsReadOnly(tag string) bool { - if tag == "" { - return false - } - parts := strings.Split(tag, ",") - for _, part := range parts { - if part == "ro" { - return true - } - } - return false -} - -func logBody(b *bytes.Buffer, body []byte) { - fmt.Fprintln(b, " --------------------------------------------------------------------------------") - fmt.Fprintln(b, string(body)) - fmt.Fprintln(b, " --------------------------------------------------------------------------------") -} diff --git a/sdk/azcore/runtime/errors.go b/sdk/azcore/runtime/errors.go new file mode 100644 index 000000000000..badf62a3bd37 --- /dev/null +++ b/sdk/azcore/runtime/errors.go @@ -0,0 +1,21 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +// NewResponseError wraps the specified error with an error that provides access to an HTTP response. +// If an HTTP request returns a non-successful status code, wrap the response and the associated error +// in this error type so that callers can access the underlying *http.Response as required. +// DO NOT wrap failed HTTP requests that returned an error and no response with this type. +func NewResponseError(inner error, resp *http.Response) error { + return shared.NewResponseError(inner, resp) +} diff --git a/sdk/azcore/policy_body_download.go b/sdk/azcore/runtime/policy_body_download.go similarity index 80% rename from sdk/azcore/policy_body_download.go rename to sdk/azcore/runtime/policy_body_download.go index 8724b11f10b8..7fca2ba044d5 100644 --- a/sdk/azcore/policy_body_download.go +++ b/sdk/azcore/runtime/policy_body_download.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "errors" @@ -13,17 +13,21 @@ import ( "io/ioutil" "net/http" "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" ) // bodyDownloadPolicy creates a policy object that downloads the response's body to a []byte. -func bodyDownloadPolicy(req *Request) (*http.Response, error) { +func bodyDownloadPolicy(req *policy.Request) (*http.Response, error) { resp, err := req.Next() if err != nil { return resp, err } - var opValues bodyDownloadPolicyOpValues + var opValues shared.BodyDownloadPolicyOpValues // don't skip downloading error response bodies - if req.OperationValue(&opValues); opValues.skip && resp.StatusCode < 400 { + if req.OperationValue(&opValues); opValues.Skip && resp.StatusCode < 400 { return resp, err } // Either bodyDownloadPolicyOpValues was not specified (so skip is false) @@ -41,10 +45,10 @@ type bodyDownloadError struct { err error } -func newBodyDownloadError(err error, req *Request) error { +func newBodyDownloadError(err error, req *policy.Request) error { // on failure, only retry the request for idempotent operations. // we currently identify them as DELETE, GET, and PUT requests. - if m := strings.ToUpper(req.Method); m == http.MethodDelete || m == http.MethodGet || m == http.MethodPut { + if m := strings.ToUpper(req.Raw().Method); m == http.MethodDelete || m == http.MethodGet || m == http.MethodPut { // error is safe for retry return err } @@ -66,12 +70,7 @@ func (b *bodyDownloadError) Unwrap() error { return b.err } -var _ NonRetriableError = (*bodyDownloadError)(nil) - -// bodyDownloadPolicyOpValues is the struct containing the per-operation values -type bodyDownloadPolicyOpValues struct { - skip bool -} +var _ errorinfo.NonRetriable = (*bodyDownloadError)(nil) // nopClosingBytesReader is an io.ReadSeekCloser around a byte slice. // It also provides direct access to the byte slice. diff --git a/sdk/azcore/policy_body_download_test.go b/sdk/azcore/runtime/policy_body_download_test.go similarity index 99% rename from sdk/azcore/policy_body_download_test.go rename to sdk/azcore/runtime/policy_body_download_test.go index a7387cbedd4f..e4e70fa42779 100644 --- a/sdk/azcore/policy_body_download_test.go +++ b/sdk/azcore/runtime/policy_body_download_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" diff --git a/sdk/azcore/runtime/policy_http_header.go b/sdk/azcore/runtime/policy_http_header.go new file mode 100644 index 000000000000..148c6d9a313d --- /dev/null +++ b/sdk/azcore/runtime/policy_http_header.go @@ -0,0 +1,31 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// newHTTPHeaderPolicy creates a policy object that adds custom HTTP headers to a request +func httpHeaderPolicy(req *policy.Request) (*http.Response, error) { + // check if any custom HTTP headers have been specified + if header := req.Raw().Context().Value(shared.CtxWithHTTPHeaderKey{}); header != nil { + for k, v := range header.(http.Header) { + // use Set to replace any existing value + // it also canonicalizes the header key + req.Raw().Header.Set(k, v[0]) + // add any remaining values + for i := 1; i < len(v); i++ { + req.Raw().Header.Add(k, v[i]) + } + } + } + return req.Next() +} diff --git a/sdk/azcore/policy_http_header_test.go b/sdk/azcore/runtime/policy_http_header_test.go similarity index 91% rename from sdk/azcore/policy_http_header_test.go rename to sdk/azcore/runtime/policy_http_header_test.go index ededd5955d17..2e9aef013492 100644 --- a/sdk/azcore/policy_http_header_test.go +++ b/sdk/azcore/runtime/policy_http_header_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" @@ -12,6 +12,7 @@ import ( "net/textproto" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -29,13 +30,13 @@ func TestAddCustomHTTPHeaderSuccess(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest)) // HTTP header policy is automatically added during pipeline construction pl := NewPipeline(srv) - req, err := NewRequest(WithHTTPHeader(context.Background(), http.Header{ + req, err := NewRequest(policy.WithHTTPHeader(context.Background(), http.Header{ customHeader: []string{customValue}, }), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } - req.Header.Set(preexistingHeader, preexistingValue) + req.Raw().Header.Set(preexistingHeader, preexistingValue) resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -81,7 +82,7 @@ func TestAddCustomHTTPHeaderOverwrite(t *testing.T) { // HTTP header policy is automatically added during pipeline construction pl := NewPipeline(srv) // overwrite the request ID with our own value - req, err := NewRequest(WithHTTPHeader(context.Background(), http.Header{ + req, err := NewRequest(policy.WithHTTPHeader(context.Background(), http.Header{ customHeader: []string{customValue}, }), http.MethodGet, srv.URL()) if err != nil { @@ -112,7 +113,7 @@ func TestAddCustomHTTPHeaderMultipleValues(t *testing.T) { // HTTP header policy is automatically added during pipeline construction pl := NewPipeline(srv) // overwrite the request ID with our own value - req, err := NewRequest(WithHTTPHeader(context.Background(), http.Header{ + req, err := NewRequest(policy.WithHTTPHeader(context.Background(), http.Header{ customHeader: []string{customValue1, customValue2}, }), http.MethodGet, srv.URL()) if err != nil { diff --git a/sdk/azcore/policy_logging.go b/sdk/azcore/runtime/policy_logging.go similarity index 60% rename from sdk/azcore/policy_logging.go rename to sdk/azcore/runtime/policy_logging.go index b97189f3757b..b013a0c45f1a 100644 --- a/sdk/azcore/policy_logging.go +++ b/sdk/azcore/runtime/policy_logging.go @@ -4,36 +4,31 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" "fmt" + "io/ioutil" "net/http" "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/diag" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) -// LogOptions configures the logging policy's behavior. -type LogOptions struct { - // IncludeBody indicates if request and response bodies should be included in logging. - // The default value is false. - // NOTE: enabling this can lead to disclosure of sensitive information, use with care. - IncludeBody bool -} - type logPolicy struct { - options LogOptions + options policy.LogOptions } // NewLogPolicy creates a RequestLogPolicy object configured using the specified options. // Pass nil to accept the default values; this is the same as passing a zero-value options. -func NewLogPolicy(o *LogOptions) Policy { +func NewLogPolicy(o *policy.LogOptions) policy.Policy { if o == nil { - o = &LogOptions{} + o = &policy.LogOptions{} } return &logPolicy{options: *o} } @@ -44,7 +39,7 @@ type logPolicyOpValues struct { start time.Time } -func (p *logPolicy) Do(req *Request) (*http.Response, error) { +func (p *logPolicy) Do(req *policy.Request) (*http.Response, error) { // Get the per-operation values. These are saved in the Message's map so that they persist across each retry calling into this policy object. var opValues logPolicyOpValues if req.OperationValue(&opValues); opValues.start.IsZero() { @@ -60,7 +55,7 @@ func (p *logPolicy) Do(req *Request) (*http.Response, error) { writeRequestWithResponse(b, req, nil, nil) var err error if p.options.IncludeBody { - err = req.writeBody(b) + err = writeReqBody(req, b) } log.Write(log.Request, b.String()) if err != nil { @@ -88,9 +83,9 @@ func (p *logPolicy) Do(req *Request) (*http.Response, error) { writeRequestWithResponse(b, req, response, err) if err != nil { // skip frames runtime.Callers() and runtime.StackTrace() - b.WriteString(diag.StackTrace(2, StackFrameCount)) + b.WriteString(diag.StackTrace(2, 32)) } else if p.options.IncludeBody { - err = writeBody(response, b) + err = writeRespBody(response, b) } log.Write(log.Response, b.String()) } @@ -109,3 +104,52 @@ func shouldLogBody(b *bytes.Buffer, contentType string) bool { fmt.Fprintf(b, " Skip logging body for %s\n", contentType) return false } + +// writes to a buffer, used for logging purposes +func writeReqBody(req *policy.Request, b *bytes.Buffer) error { + if req.Raw().Body == nil { + fmt.Fprint(b, " Request contained no body\n") + return nil + } + if ct := req.Raw().Header.Get(shared.HeaderContentType); !shouldLogBody(b, ct) { + return nil + } + body, err := ioutil.ReadAll(req.Raw().Body) + if err != nil { + fmt.Fprintf(b, " Failed to read request body: %s\n", err.Error()) + return err + } + if err := req.RewindBody(); err != nil { + return err + } + logBody(b, body) + return nil +} + +// writes to a buffer, used for logging purposes +func writeRespBody(resp *http.Response, b *bytes.Buffer) error { + ct := resp.Header.Get(shared.HeaderContentType) + if ct == "" { + fmt.Fprint(b, " Response contained no body\n") + return nil + } else if !shouldLogBody(b, ct) { + return nil + } + body, err := Payload(resp) + if err != nil { + fmt.Fprintf(b, " Failed to read response body: %s\n", err.Error()) + return err + } + if len(body) > 0 { + logBody(b, body) + } else { + fmt.Fprint(b, " Response contained no body\n") + } + return nil +} + +func logBody(b *bytes.Buffer, body []byte) { + fmt.Fprintln(b, " --------------------------------------------------------------------------------") + fmt.Fprintln(b, string(body)) + fmt.Fprintln(b, " --------------------------------------------------------------------------------") +} diff --git a/sdk/azcore/policy_logging_test.go b/sdk/azcore/runtime/policy_logging_test.go similarity index 95% rename from sdk/azcore/policy_logging_test.go rename to sdk/azcore/runtime/policy_logging_test.go index 6524ae978ce9..6ef960ce775f 100644 --- a/sdk/azcore/policy_logging_test.go +++ b/sdk/azcore/runtime/policy_logging_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -31,10 +31,10 @@ func TestPolicyLoggingSuccess(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - qp := req.URL.Query() + qp := req.Raw().URL.Query() qp.Set("one", "fish") qp.Set("sig", "redact") - req.URL.RawQuery = qp.Encode() + req.Raw().URL.RawQuery = qp.Encode() resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -81,8 +81,8 @@ func TestPolicyLoggingError(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - req.Header.Add("header", "one") - req.Header.Add("Authorization", "redact") + req.Raw().Header.Add("header", "one") + req.Raw().Header.Add("Authorization", "redact") resp, err := pl.Do(req) if err == nil { t.Fatal("unexpected nil error") diff --git a/sdk/azcore/policy_retry.go b/sdk/azcore/runtime/policy_retry.go similarity index 60% rename from sdk/azcore/policy_retry.go rename to sdk/azcore/runtime/policy_retry.go index b916f28fbadd..55eedd1a5652 100644 --- a/sdk/azcore/policy_retry.go +++ b/sdk/azcore/runtime/policy_retry.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" @@ -15,46 +15,15 @@ import ( "net/http" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) -const ( - defaultMaxRetries = 3 -) - -// RetryOptions configures the retry policy's behavior. -// All zero-value fields will be initialized with their default values. -type RetryOptions struct { - // MaxRetries specifies the maximum number of attempts a failed operation will be retried - // before producing an error. - // The default value is three. A value less than zero means one try and no retries. - MaxRetries int32 - - // TryTimeout indicates the maximum time allowed for any single try of an HTTP request. - // This is disabled by default. Specify a value greater than zero to enable. - // NOTE: Setting this to a small value might cause premature HTTP request time-outs. - TryTimeout time.Duration - - // RetryDelay specifies the initial amount of delay to use before retrying an operation. - // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay. - // The default value is four seconds. A value less than zero means no delay between retries. - RetryDelay time.Duration - - // MaxRetryDelay specifies the maximum delay allowed before retrying an operation. - // Typically the value is greater than or equal to the value specified in RetryDelay. - // The default Value is 120 seconds. A value less than zero means there is no cap. - MaxRetryDelay time.Duration - - // StatusCodes specifies the HTTP status codes that indicate the operation should be retried. - // The default value is the status codes in StatusCodesForRetry. - // Specifying an empty slice will cause retries to happen only for transport errors. - StatusCodes []int -} - -// init sets any default values -func (o *RetryOptions) init() { +func setDefaults(o *policy.RetryOptions) { if o.MaxRetries == 0 { - o.MaxRetries = defaultMaxRetries + o.MaxRetries = shared.DefaultMaxRetries } else if o.MaxRetries < 0 { o.MaxRetries = 0 } @@ -80,17 +49,7 @@ func (o *RetryOptions) init() { } } -// used as a context key for adding/retrieving RetryOptions -type ctxWithRetryOptionsKey struct{} - -// WithRetryOptions adds the specified RetryOptions to the parent context. -// Use this to specify custom RetryOptions at the API-call level. -func WithRetryOptions(parent context.Context, options RetryOptions) context.Context { - options.init() - return context.WithValue(parent, ctxWithRetryOptionsKey{}, options) -} - -func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never 0 +func calcDelay(o policy.RetryOptions, try int32) time.Duration { // try is >=1; never 0 pow := func(number int64, exponent int32) int64 { // pow is nested helper function var result int64 = 1 for n := int32(0); n < exponent; n++ { @@ -111,42 +70,38 @@ func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never // NewRetryPolicy creates a policy object configured using the specified options. // Pass nil to accept the default values; this is the same as passing a zero-value options. -func NewRetryPolicy(o *RetryOptions) Policy { +func NewRetryPolicy(o *policy.RetryOptions) policy.Policy { if o == nil { - o = &RetryOptions{} + o = &policy.RetryOptions{} } p := &retryPolicy{options: *o} - // fix up values in the copy - p.options.init() return p } type retryPolicy struct { - options RetryOptions + options policy.RetryOptions } -func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { +func (p *retryPolicy) Do(req *policy.Request) (resp *http.Response, err error) { options := p.options // check if the retry options have been overridden for this call - if override := req.Context().Value(ctxWithRetryOptionsKey{}); override != nil { - options = override.(RetryOptions) + if override := req.Raw().Context().Value(shared.CtxWithRetryOptionsKey{}); override != nil { + options = override.(policy.RetryOptions) } + setDefaults(&options) // Exponential retry algorithm: ((2 ^ attempt) - 1) * delay * random(0.8, 1.2) // When to retry: connection failure or temporary/timeout. - if req.body != nil { - // wrap the body so we control when it's actually closed - rwbody := &retryableRequestBody{body: req.body} - req.body = rwbody - req.Request.GetBody = func() (io.ReadCloser, error) { - _, err := rwbody.Seek(0, io.SeekStart) // Seek back to the beginning of the stream - return rwbody, err - } + var rwbody *retryableRequestBody + if req.Body() != nil { + // wrap the body so we control when it's actually closed. + // do this outside the for loop so defers don't accumulate. + rwbody = &retryableRequestBody{body: req.Body()} defer rwbody.realClose() } try := int32(1) for { resp = nil // reset - log.Writef(log.RetryPolicy, "\n=====> Try=%d %s %s", try, req.Method, req.URL.String()) + log.Writef(log.RetryPolicy, "\n=====> Try=%d %s %s", try, req.Raw().Method, req.Raw().URL.String()) // For each try, seek to the beginning of the Body stream. We do this even for the 1st try because // the stream may not be at offset 0 when we first get it and we want the same behavior for the @@ -155,13 +110,17 @@ func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { if err != nil { return } + // RewindBody() restores Raw().Body to its original state, so set our rewindable after + if rwbody != nil { + req.Raw().Body = rwbody + } if options.TryTimeout == 0 { resp, err = req.Next() } else { // Set the per-try time for this particular retry operation and then Do the operation. - tryCtx, tryCancel := context.WithTimeout(req.Context(), options.TryTimeout) - clone := req.clone(tryCtx) + tryCtx, tryCancel := context.WithTimeout(req.Raw().Context(), options.TryTimeout) + clone := req.Clone(tryCtx) resp, err = clone.Next() // Make the request tryCancel() } @@ -174,7 +133,7 @@ func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { if err == nil && !HasStatusCode(resp, options.StatusCodes...) { // if there is no error and the response code isn't in the list of retry codes then we're done. return - } else if ctxErr := req.Context().Err(); ctxErr != nil { + } else if ctxErr := req.Raw().Context().Err(); ctxErr != nil { // don't retry if the parent context has been cancelled or its deadline exceeded err = ctxErr log.Writef(log.RetryPolicy, "abort due to %v", err) @@ -182,7 +141,7 @@ func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { } // check if the error is not retriable - var nre NonRetriableError + var nre errorinfo.NonRetriable if errors.As(err, &nre) { // the error says it's not retriable so don't retry log.Writef(log.RetryPolicy, "non-retriable error %T", nre) @@ -199,16 +158,16 @@ func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { Drain(resp) // use the delay from retry-after if available - delay := RetryAfter(resp) + delay := shared.RetryAfter(resp) if delay <= 0 { - delay = options.calcDelay(try) + delay = calcDelay(options, try) } log.Writef(log.RetryPolicy, "End Try #%d, Delay=%v", try, delay) select { case <-time.After(delay): try++ - case <-req.Context().Done(): - err = req.Context().Err() + case <-req.Raw().Context().Done(): + err = req.Raw().Context().Err() log.Writef(log.RetryPolicy, "abort due to %v", err) return } diff --git a/sdk/azcore/policy_retry_test.go b/sdk/azcore/runtime/policy_retry_test.go similarity index 90% rename from sdk/azcore/policy_retry_test.go rename to sdk/azcore/runtime/policy_retry_test.go index f7ab18e4c556..89d46c3ac260 100644 --- a/sdk/azcore/policy_retry_test.go +++ b/sdk/azcore/runtime/policy_retry_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -17,13 +17,17 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) -func testRetryOptions() *RetryOptions { - def := RetryOptions{} - def.RetryDelay = 20 * time.Millisecond - return &def +func testRetryOptions() *policy.RetryOptions { + return &policy.RetryOptions{ + RetryDelay: 20 * time.Millisecond, + } } func TestRetryPolicySuccess(t *testing.T) { @@ -74,10 +78,10 @@ func TestRetryPolicyFailOnStatusCode(t *testing.T) { if resp.StatusCode != http.StatusInternalServerError { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } - if body.rcount != defaultMaxRetries { + if body.rcount != shared.DefaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -92,12 +96,12 @@ func TestRetryPolicyFailOnStatusCodeRespBodyPreserved(t *testing.T) { srv.SetResponse(mock.WithStatusCode(http.StatusInternalServerError), mock.WithBody([]byte(respBody))) // add a per-request policy that reads and restores the request body. // this is to simulate how something like httputil.DumpRequest works. - pl := NewPipeline(srv, policyFunc(func(r *Request) (*http.Response, error) { - b, err := ioutil.ReadAll(r.Body) + pl := NewPipeline(srv, pipeline.PolicyFunc(func(r *policy.Request) (*http.Response, error) { + b, err := ioutil.ReadAll(r.Raw().Body) if err != nil { t.Fatal(err) } - r.Body = ioutil.NopCloser(bytes.NewReader(b)) + r.Raw().Body = ioutil.NopCloser(bytes.NewReader(b)) return r.Next() }), NewRetryPolicy(testRetryOptions())) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) @@ -115,10 +119,10 @@ func TestRetryPolicyFailOnStatusCodeRespBodyPreserved(t *testing.T) { if resp.StatusCode != http.StatusInternalServerError { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } - if body.rcount != defaultMaxRetries { + if body.rcount != shared.DefaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -210,7 +214,7 @@ func TestRetryPolicyNoRetries(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout)) srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) srv.AppendResponse() - pl := NewPipeline(srv, NewRetryPolicy(&RetryOptions{MaxRetries: -1})) + pl := NewPipeline(srv, NewRetryPolicy(&policy.RetryOptions{MaxRetries: -1})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -273,10 +277,10 @@ func TestRetryPolicyFailOnError(t *testing.T) { if resp != nil { t.Fatal("unexpected response") } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } - if body.rcount != defaultMaxRetries { + if body.rcount != shared.DefaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -307,10 +311,10 @@ func TestRetryPolicySuccessWithRetryComplex(t *testing.T) { if resp.StatusCode != http.StatusAccepted { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } - if body.rcount != defaultMaxRetries { + if body.rcount != shared.DefaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -360,7 +364,7 @@ func (f fatalError) NonRetriable() { // marker method } -var _ NonRetriableError = (*fatalError)(nil) +var _ errorinfo.NonRetriable = (*fatalError)(nil) func TestRetryPolicyIsNotRetriable(t *testing.T) { theErr := fatalError{s: "it's dead Jim"} @@ -395,7 +399,7 @@ func TestWithRetryOptions(t *testing.T) { customOptions := *defaultOptions customOptions.MaxRetries = 10 customOptions.MaxRetryDelay = 200 * time.Millisecond - retryCtx := WithRetryOptions(context.Background(), customOptions) + retryCtx := policy.WithRetryOptions(context.Background(), customOptions) req, err := NewRequest(retryCtx, http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -437,8 +441,8 @@ func TestRetryPolicyFailOnErrorNoDownload(t *testing.T) { if resp != nil { t.Fatal("unexpected response") } - if r := srv.Requests(); r != defaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) + if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) } } @@ -613,7 +617,7 @@ func (r *rewindTrackingBody) Seek(offset int64, whence int) (int64, error) { // used to inject a nil response type nilRespInjector struct { - t Transporter + t policy.Transporter c int // the current request number r []int // the list of request numbers to return a nil response (one-based) } diff --git a/sdk/azcore/policy_telemetry.go b/sdk/azcore/runtime/policy_telemetry.go similarity index 54% rename from sdk/azcore/policy_telemetry.go rename to sdk/azcore/runtime/policy_telemetry.go index a3b0bc09eace..5e628e7a3257 100644 --- a/sdk/azcore/policy_telemetry.go +++ b/sdk/azcore/runtime/policy_telemetry.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -13,32 +13,21 @@ import ( "os" "runtime" "strings" -) - -// TelemetryOptions configures the telemetry policy's behavior. -type TelemetryOptions struct { - // Value is a string prepended to each request's User-Agent and sent to the service. - // The service records the user-agent in logs for diagnostics and tracking of client requests. - Value string - - // ApplicationID is an application-specific identification string used in telemetry. - // It has a maximum length of 24 characters and must not contain any spaces. - ApplicationID string - // Disabled will prevent the addition of any telemetry data to the User-Agent. - Disabled bool -} + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) type telemetryPolicy struct { telemetryValue string } // NewTelemetryPolicy creates a telemetry policy object that adds telemetry information to outgoing HTTP requests. -// The format is [ ]azsdk--/ []. +// The format is [ ]azsdk-go-/ . // Pass nil to accept the default values; this is the same as passing a zero-value options. -func NewTelemetryPolicy(o *TelemetryOptions) Policy { +func NewTelemetryPolicy(mod, ver string, o *policy.TelemetryOptions) policy.Policy { if o == nil { - o = &TelemetryOptions{} + o = &policy.TelemetryOptions{} } tp := telemetryPolicy{} if o.Disabled { @@ -54,31 +43,29 @@ func NewTelemetryPolicy(o *TelemetryOptions) Policy { b.WriteString(o.ApplicationID) b.WriteRune(' ') } - // write out telemetry string - if o.Value != "" { - b.WriteString(o.Value) - b.WriteRune(' ') - } - b.WriteString(UserAgent) + b.WriteString(formatTelemetry(mod, ver)) + b.WriteRune(' ') + // inject azcore info + b.WriteString(formatTelemetry(shared.Module, shared.Version)) b.WriteRune(' ') b.WriteString(platformInfo) tp.telemetryValue = b.String() return &tp } -func (p telemetryPolicy) Do(req *Request) (*http.Response, error) { +func formatTelemetry(comp, ver string) string { + return fmt.Sprintf("azsdk-go-%s/%s", comp, ver) +} + +func (p telemetryPolicy) Do(req *policy.Request) (*http.Response, error) { if p.telemetryValue == "" { return req.Next() } // preserve the existing User-Agent string - if ua := req.Request.Header.Get(headerUserAgent); ua != "" { + if ua := req.Raw().Header.Get(shared.HeaderUserAgent); ua != "" { p.telemetryValue = fmt.Sprintf("%s %s", p.telemetryValue, ua) } - var rt requestTelemetry - if req.OperationValue(&rt) { - p.telemetryValue = fmt.Sprintf("%s %s", string(rt), p.telemetryValue) - } - req.Request.Header.Set(headerUserAgent, p.telemetryValue) + req.Raw().Header.Set(shared.HeaderUserAgent, p.telemetryValue) return req.Next() } @@ -93,6 +80,3 @@ var platformInfo = func() string { } return fmt.Sprintf("(%s; %s)", runtime.Version(), operatingSystem) }() - -// used for adding per-request telemetry -type requestTelemetry string diff --git a/sdk/azcore/policy_telemetry_test.go b/sdk/azcore/runtime/policy_telemetry_test.go similarity index 52% rename from sdk/azcore/policy_telemetry_test.go rename to sdk/azcore/runtime/policy_telemetry_test.go index 6149eadae55f..03ea6949615e 100644 --- a/sdk/azcore/policy_telemetry_test.go +++ b/sdk/azcore/runtime/policy_telemetry_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" @@ -12,16 +12,18 @@ import ( "net/http" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) -var defaultTelemetry = UserAgent + " " + platformInfo +var defaultTelemetry = "azsdk-go-" + shared.Module + "/" + shared.Version + " " + platformInfo func TestPolicyTelemetryDefault(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse() - pl := NewPipeline(srv, NewTelemetryPolicy(nil)) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -30,26 +32,7 @@ func TestPolicyTelemetryDefault(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != defaultTelemetry { - t.Fatalf("unexpected user agent value: %s", v) - } -} - -func TestPolicyTelemetryWithCustomInfo(t *testing.T) { - srv, close := mock.NewServer() - defer close() - srv.SetResponse() - const testValue = "azcore_test" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{Value: testValue})) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - resp, err := pl.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s", testValue, defaultTelemetry) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != "azsdk-go-test/v1.2.3 "+defaultTelemetry { t.Fatalf("unexpected user agent value: %s", v) } } @@ -58,18 +41,18 @@ func TestPolicyTelemetryPreserveExisting(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse() - pl := NewPipeline(srv, NewTelemetryPolicy(nil)) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } const otherValue = "this should stay" - req.Header.Set(headerUserAgent, otherValue) + req.Raw().Header.Set(shared.HeaderUserAgent, otherValue) resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s", defaultTelemetry, otherValue) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != fmt.Sprintf("%s %s", "azsdk-go-test/v1.2.3 "+defaultTelemetry, otherValue) { t.Fatalf("unexpected user agent value: %s", v) } } @@ -79,36 +62,16 @@ func TestPolicyTelemetryWithAppID(t *testing.T) { defer close() srv.SetResponse() const appID = "my_application" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID})) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - resp, err := pl.Do(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s", appID, defaultTelemetry) { - t.Fatalf("unexpected user agent value: %s", v) - } -} - -func TestPolicyTelemetryWithAppIDAndReqTelemetry(t *testing.T) { - srv, close := mock.NewServer() - defer close() - srv.SetResponse() - const appID = "my_application" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID})) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } - req.Telemetry("TestPolicyTelemetryWithAppIDAndReqTelemetry") resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s %s", "TestPolicyTelemetryWithAppIDAndReqTelemetry", appID, defaultTelemetry) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != fmt.Sprintf("%s %s", appID, "azsdk-go-test/v1.2.3 "+defaultTelemetry) { t.Fatalf("unexpected user agent value: %s", v) } } @@ -118,7 +81,7 @@ func TestPolicyTelemetryWithAppIDSanitized(t *testing.T) { defer close() srv.SetResponse() const appID = "This will get the spaces removed and truncated." - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID})) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -128,7 +91,7 @@ func TestPolicyTelemetryWithAppIDSanitized(t *testing.T) { t.Fatalf("unexpected error: %v", err) } const newAppID = "This/will/get/the/spaces" - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s", newAppID, defaultTelemetry) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != fmt.Sprintf("%s %s", newAppID, "azsdk-go-test/v1.2.3 "+defaultTelemetry) { t.Fatalf("unexpected user agent value: %s", v) } } @@ -138,18 +101,18 @@ func TestPolicyTelemetryPreserveExistingWithAppID(t *testing.T) { defer close() srv.SetResponse() const appID = "my_application" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID})) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } const otherValue = "this should stay" - req.Header.Set(headerUserAgent, otherValue) + req.Raw().Header.Set(shared.HeaderUserAgent, otherValue) resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != fmt.Sprintf("%s %s %s", appID, defaultTelemetry, otherValue) { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != fmt.Sprintf("%s %s %s", appID, "azsdk-go-test/v1.2.3 "+defaultTelemetry, otherValue) { t.Fatalf("unexpected user agent value: %s", v) } } @@ -159,17 +122,16 @@ func TestPolicyTelemetryDisabled(t *testing.T) { defer close() srv.SetResponse() const appID = "my_application" - pl := NewPipeline(srv, NewTelemetryPolicy(&TelemetryOptions{ApplicationID: appID, Disabled: true})) + pl := NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID, Disabled: true})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } - req.Telemetry("this should be ignored") resp, err := pl.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if v := resp.Request.Header.Get(headerUserAgent); v != "" { + if v := resp.Request.Header.Get(shared.HeaderUserAgent); v != "" { t.Fatalf("unexpected user agent value: %s", v) } } diff --git a/sdk/azcore/runtime/poller.go b/sdk/azcore/runtime/poller.go new file mode 100644 index 000000000000..686c04725da6 --- /dev/null +++ b/sdk/azcore/runtime/poller.go @@ -0,0 +1,70 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// NewPoller creates a Poller based on the provided initial response. +// pollerID - a unique identifier for an LRO, it's usually the client.Method string. +func NewPoller(pollerID string, resp *http.Response, pl pipeline.Pipeline, eu func(*http.Response) error) (*pollers.Poller, error) { + // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). + // ideally the codegen should return an error if the initial response failed and not even create a poller. + if !pollers.StatusCodeValid(resp) { + return nil, errors.New("the operation failed or was cancelled") + } + // determine the polling method + var lro pollers.Operation + var err error + // op poller must be checked first as it can also have a location header + if op.Applicable(resp) { + lro, err = op.New(resp, pollerID) + } else if loc.Applicable(resp) { + lro, err = loc.New(resp, pollerID) + } else { + lro = &pollers.NopPoller{} + } + if err != nil { + return nil, err + } + return pollers.NewPoller(lro, resp, pl, eu), nil +} + +// NewPollerFromResumeToken creates a Poller from a resume token string. +// pollerID - a unique identifier for an LRO, it's usually the client.Method string. +func NewPollerFromResumeToken(pollerID string, token string, pl pipeline.Pipeline, eu func(*http.Response) error) (*pollers.Poller, error) { + kind, err := pollers.KindFromToken(pollerID, token) + if err != nil { + return nil, err + } + // now rehydrate the poller based on the encoded poller type + var lro pollers.Operation + switch kind { + case loc.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", loc.Kind) + lro = &loc.Poller{} + case op.Kind: + log.Writef(log.LongRunningOperation, "Resuming %s poller.", op.Kind) + lro = &op.Poller{} + default: + return nil, fmt.Errorf("unhandled poller type %s", kind) + } + if err = json.Unmarshal([]byte(token), lro); err != nil { + return nil, err + } + return pollers.NewPoller(lro, nil, pl, eu), nil +} diff --git a/sdk/azcore/poller_test.go b/sdk/azcore/runtime/poller_test.go similarity index 97% rename from sdk/azcore/poller_test.go rename to sdk/azcore/runtime/poller_test.go index 36da6e548048..b7d69a1afe53 100644 --- a/sdk/azcore/poller_test.go +++ b/sdk/azcore/runtime/poller_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" @@ -12,9 +12,11 @@ import ( "fmt" "net/http" "net/url" + "reflect" "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -75,26 +77,18 @@ func TestNewPollerFromResumeTokenFail(t *testing.T) { } } -func TestOpPollerSimple(t *testing.T) { +func TestLocPollerSimple(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{ "status": "Succeeded"}`))) defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPut, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -111,28 +105,18 @@ func TestOpPollerSimple(t *testing.T) { } } -func TestOpPollerWithWidgetPUT(t *testing.T) { +func TestLocPollerWithWidget(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`)), mock.WithHeader("Retry-After", "1")) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded"}`))) - // PUT and PATCH state that a final GET will happen - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 3}`))) - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPut, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -148,34 +132,23 @@ func TestOpPollerWithWidgetPUT(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 2 { + if w.Size != 3 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestOpPollerWithWidgetPOSTLocation(t *testing.T) { +func TestLocPollerCancelled(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded"}`))) - // POST state that a final GET will happen from the URL provided in the Location header if available - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(`{"error": "cancelled"}`))) - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPost, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -185,38 +158,32 @@ func TestOpPollerWithWidgetPOSTLocation(t *testing.T) { } var w widget resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) - if err != nil { - t.Fatal(err) + if err == nil { + t.Fatal("unexpected nil error") } - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status code %d", resp.StatusCode) + if _, ok := err.(pollerError); !ok { + t.Fatal("expected pollerError") } - if w.Size != 2 { + if resp != nil { + t.Fatal("expected nil response") + } + if w.Size != 0 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestOpPollerWithWidgetPOST(t *testing.T) { +func TestLocPollerWithError(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - // POST with no location header means the success response returns the model - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded", "size": 2}`))) defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendError(errors.New("oops")) - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPost, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -226,41 +193,32 @@ func TestOpPollerWithWidgetPOST(t *testing.T) { } var w widget resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) - if err != nil { - t.Fatal(err) + if err == nil { + t.Fatal("unexpected nil error") } - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status code %d", resp.StatusCode) + if e := err.Error(); e != "oops" { + t.Fatalf("expected error %s", e) } - if w.Size != 2 { + if resp != nil { + t.Fatal("expected nil response") + } + if w.Size != 0 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestOpPollerWithWidgetResourceLocation(t *testing.T) { +func TestLocPollerWithResumeToken(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte( - fmt.Sprintf(`{"status": "Succeeded", "resourceLocation": "%s"}`, srv.URL())))) - // final GET will happen from the URL provided in the resourceLocation - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) defer close() - reqURL, err := url.Parse(srv.URL()) - if err != nil { - t.Fatal(err) - } firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPatch, - URL: reqURL, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -268,20 +226,70 @@ func TestOpPollerWithWidgetResourceLocation(t *testing.T) { if err != nil { t.Fatal(err) } - var w widget - resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) + resp, err := lro.Poll(context.Background()) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("unexpected status code %d", resp.StatusCode) + } + if lro.Done() { + t.Fatal("poller shouldn't be done yet") + } + resp, err = lro.FinalResponse(context.Background(), nil) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } + tk, err := lro.ResumeToken() + if err != nil { + t.Fatal(err) + } + lro, err = NewPollerFromResumeToken("fake.poller", tk, pl, errUnmarshall) + if err != nil { + t.Fatal(err) + } + resp, err = lro.PollUntilDone(context.Background(), 5*time.Millisecond, nil) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 2 { - t.Fatalf("unexpected widget size %d", w.Size) +} + +func TestLocPollerWithTimeout(t *testing.T) { + srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) + srv.AppendResponse(mock.WithSlowResponse(2 * time.Second)) + defer close() + + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Location": []string{srv.URL()}, + }, + } + pl := NewPipeline(srv) + lro, err := NewPoller("fake.poller", firstResp, pl, errUnmarshall) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + resp, err := lro.PollUntilDone(ctx, 5*time.Millisecond, nil) + cancel() + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") } } -func TestOpPollerWithResumeToken(t *testing.T) { +func TestOpPollerSimple(t *testing.T) { srv, close := mock.NewServer() srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) @@ -293,13 +301,14 @@ func TestOpPollerWithResumeToken(t *testing.T) { t.Fatal(err) } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ "Operation-Location": []string{srv.URL()}, "Retry-After": []string{"1"}, }, Request: &http.Request{ - Method: http.MethodPut, + Method: http.MethodDelete, URL: reqURL, }, } @@ -308,32 +317,7 @@ func TestOpPollerWithResumeToken(t *testing.T) { if err != nil { t.Fatal(err) } - resp, err := lro.Poll(context.Background()) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != http.StatusAccepted { - t.Fatalf("unexpected status code %d", resp.StatusCode) - } - if lro.Done() { - t.Fatal("poller shouldn't be done yet") - } - resp, err = lro.FinalResponse(context.Background(), nil) - if err == nil { - t.Fatal("unexpected nil error") - } - if resp != nil { - t.Fatal("expected nil response") - } - tk, err := lro.ResumeToken() - if err != nil { - t.Fatal(err) - } - lro, err = NewPollerFromResumeToken("fake.poller", tk, pl, errUnmarshall) - if err != nil { - t.Fatal(err) - } - resp, err = lro.PollUntilDone(context.Background(), 5*time.Millisecond, nil) + resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, nil) if err != nil { t.Fatal(err) } @@ -342,18 +326,29 @@ func TestOpPollerWithResumeToken(t *testing.T) { } } -func TestLocPollerSimple(t *testing.T) { +func TestOpPollerWithWidgetPUT(t *testing.T) { srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`)), mock.WithHeader("Retry-After", "1")) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded"}`))) + // PUT and PATCH state that a final GET will happen + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPut, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -361,27 +356,43 @@ func TestLocPollerSimple(t *testing.T) { if err != nil { t.Fatal(err) } - resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, nil) + var w widget + resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code %d", resp.StatusCode) } + if w.Size != 2 { + t.Fatalf("unexpected widget size %d", w.Size) + } } -func TestLocPollerWithWidget(t *testing.T) { +func TestOpPollerWithWidgetPOSTLocation(t *testing.T) { srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded"}`))) + // POST state that a final GET will happen from the URL provided in the Location header if available + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 3}`))) + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPost, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -397,23 +408,33 @@ func TestLocPollerWithWidget(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 3 { + if w.Size != 2 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestLocPollerCancelled(t *testing.T) { +func TestOpPollerWithWidgetPOST(t *testing.T) { srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + // POST with no location header means the success response returns the model + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"status": "Succeeded", "size": 2}`))) defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(`{"error": "cancelled"}`))) + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPost, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -423,32 +444,42 @@ func TestLocPollerCancelled(t *testing.T) { } var w widget resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) - if err == nil { - t.Fatal("unexpected nil error") - } - if _, ok := err.(pollerError); !ok { - t.Fatal("expected pollerError") + if err != nil { + t.Fatal(err) } - if resp != nil { - t.Fatal("expected nil response") + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 0 { + if w.Size != 2 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestLocPollerWithError(t *testing.T) { +func TestOpPollerWithWidgetResourceLocation(t *testing.T) { srv, close := mock.NewServer() + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{"status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte( + fmt.Sprintf(`{"status": "Succeeded", "resourceLocation": "%s"}`, srv.URL())))) + // final GET will happen from the URL provided in the resourceLocation + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 2}`))) defer close() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendError(errors.New("oops")) + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPatch, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -458,32 +489,38 @@ func TestLocPollerWithError(t *testing.T) { } var w widget resp, err := lro.PollUntilDone(context.Background(), 5*time.Millisecond, &w) - if err == nil { - t.Fatal("unexpected nil error") - } - if e := err.Error(); e != "oops" { - t.Fatalf("expected error %s", e) + if err != nil { + t.Fatal(err) } - if resp != nil { - t.Fatal("expected nil response") + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code %d", resp.StatusCode) } - if w.Size != 0 { + if w.Size != 2 { t.Fatalf("unexpected widget size %d", w.Size) } } -func TestLocPollerWithResumeToken(t *testing.T) { +func TestOpPollerWithResumeToken(t *testing.T) { srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted), mock.WithBody([]byte(`{ "status": "InProgress"}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{ "status": "Succeeded"}`))) defer close() + reqURL, err := url.Parse(srv.URL()) + if err != nil { + t.Fatal(err) + } firstResp := &http.Response{ + Body: http.NoBody, StatusCode: http.StatusAccepted, Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodDelete, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -525,35 +562,6 @@ func TestLocPollerWithResumeToken(t *testing.T) { } } -func TestLocPollerWithTimeout(t *testing.T) { - srv, close := mock.NewServer() - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - srv.AppendResponse(mock.WithSlowResponse(2 * time.Second)) - defer close() - - firstResp := &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Location": []string{srv.URL()}, - }, - } - pl := NewPipeline(srv) - lro, err := NewPoller("fake.poller", firstResp, pl, errUnmarshall) - if err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - resp, err := lro.PollUntilDone(ctx, 5*time.Millisecond, nil) - cancel() - if err == nil { - t.Fatal("unexpected nil error") - } - if resp != nil { - t.Fatal("expected nil response") - } -} - func TestNopPoller(t *testing.T) { firstResp := &http.Response{ StatusCode: http.StatusOK, @@ -563,8 +571,8 @@ func TestNopPoller(t *testing.T) { if err != nil { t.Fatal(err) } - if _, ok := lro.lro.(*nopPoller); !ok { - t.Fatalf("unexpected poller type %T", lro.lro) + if pt := pollers.PollerType(lro); pt != reflect.TypeOf(&pollers.NopPoller{}) { + t.Fatalf("unexpected poller type %s", pt.String()) } if !lro.Done() { t.Fatal("expected Done() for nopPoller") diff --git a/sdk/azcore/runtime/request.go b/sdk/azcore/runtime/request.go new file mode 100644 index 000000000000..d72b68791c4e --- /dev/null +++ b/sdk/azcore/runtime/request.go @@ -0,0 +1,228 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "mime/multipart" + "reflect" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// Pipeline represents a primitive for sending HTTP requests and receiving responses. +// Its behavior can be extended by specifying policies during construction. +type Pipeline = pipeline.Pipeline + +// Base64Encoding is usesd to specify which base-64 encoder/decoder to use when +// encoding/decoding a slice of bytes to/from a string. +type Base64Encoding int + +const ( + // Base64StdFormat uses base64.StdEncoding for encoding and decoding payloads. + Base64StdFormat Base64Encoding = 0 + + // Base64URLFormat uses base64.RawURLEncoding for encoding and decoding payloads. + Base64URLFormat Base64Encoding = 1 +) + +// NewRequest creates a new policy.Request with the specified input. +func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*pipeline.Request, error) { + return pipeline.NewRequest(ctx, httpMethod, endpoint) +} + +// NewPipeline creates a new Pipeline object from the specified Transport and Policies. +// If no transport is provided then the default *http.Client transport will be used. +func NewPipeline(transport pipeline.Transporter, policies ...pipeline.Policy) pipeline.Pipeline { + if transport == nil { + transport = defaultHTTPClient + } + // transport policy must always be the last in the slice + policies = append(policies, pipeline.PolicyFunc(httpHeaderPolicy), pipeline.PolicyFunc(bodyDownloadPolicy)) + return pipeline.NewPipeline(transport, policies...) +} + +// JoinPaths concatenates multiple URL path segments into one path, +// inserting path separation characters as required. JoinPaths will preserve +// query parameters in the root path +func JoinPaths(root string, paths ...string) string { + if len(paths) == 0 { + return root + } + + qps := "" + if strings.Contains(root, "?") { + splitPath := strings.Split(root, "?") + root, qps = splitPath[0], splitPath[1] + } + + for i := 0; i < len(paths); i++ { + root = strings.TrimRight(root, "/") + paths[i] = strings.TrimLeft(paths[i], "/") + root += "/" + paths[i] + } + + if qps != "" { + if !strings.HasSuffix(root, "/") { + root += "/" + } + return root + "?" + qps + } + return root +} + +// EncodeByteArray will base-64 encode the byte slice v. +func EncodeByteArray(v []byte, format Base64Encoding) string { + if format == Base64URLFormat { + return base64.RawURLEncoding.EncodeToString(v) + } + return base64.StdEncoding.EncodeToString(v) +} + +// MarshalAsByteArray will base-64 encode the byte slice v, then calls SetBody. +// The encoded value is treated as a JSON string. +func MarshalAsByteArray(req *policy.Request, v []byte, format Base64Encoding) error { + // send as a JSON string + encode := fmt.Sprintf("\"%s\"", EncodeByteArray(v, format)) + return req.SetBody(shared.NopCloser(strings.NewReader(encode)), shared.ContentTypeAppJSON) +} + +// MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody. +func MarshalAsJSON(req *policy.Request, v interface{}) error { + v = cloneWithoutReadOnlyFields(v) + b, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("error marshalling type %T: %s", v, err) + } + return req.SetBody(shared.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppJSON) +} + +// MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody. +func MarshalAsXML(req *policy.Request, v interface{}) error { + b, err := xml.Marshal(v) + if err != nil { + return fmt.Errorf("error marshalling type %T: %s", v, err) + } + return req.SetBody(shared.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppXML) +} + +// SetMultipartFormData writes the specified keys/values as multi-part form +// fields with the specified value. File content must be specified as a ReadSeekCloser. +// All other values are treated as string values. +func SetMultipartFormData(req *policy.Request, formData map[string]interface{}) error { + body := bytes.Buffer{} + writer := multipart.NewWriter(&body) + for k, v := range formData { + if rsc, ok := v.(io.ReadSeekCloser); ok { + // this is the body to upload, the key is its file name + fd, err := writer.CreateFormFile(k, k) + if err != nil { + return err + } + // copy the data to the form file + if _, err = io.Copy(fd, rsc); err != nil { + return err + } + continue + } + // ensure the value is in string format + s, ok := v.(string) + if !ok { + s = fmt.Sprintf("%v", v) + } + if err := writer.WriteField(k, s); err != nil { + return err + } + } + if err := writer.Close(); err != nil { + return err + } + return req.SetBody(shared.NopCloser(bytes.NewReader(body.Bytes())), writer.FormDataContentType()) +} + +// returns a clone of the object graph pointed to by v, omitting values of all read-only +// fields. if there are no read-only fields in the object graph, no clone is created. +func cloneWithoutReadOnlyFields(v interface{}) interface{} { + val := reflect.Indirect(reflect.ValueOf(v)) + if val.Kind() != reflect.Struct { + // not a struct, skip + return v + } + // first walk the graph to find any R/O fields. + // if there aren't any, skip cloning the graph. + if !recursiveFindReadOnlyField(val) { + return v + } + return recursiveCloneWithoutReadOnlyFields(val) +} + +// returns true if any field in the object graph of val contains the `azure:"ro"` tag value +func recursiveFindReadOnlyField(val reflect.Value) bool { + t := val.Type() + // iterate over the fields, looking for the "azure" tag. + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + aztag := field.Tag.Get("azure") + if azureTagIsReadOnly(aztag) { + return true + } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct && recursiveFindReadOnlyField(reflect.Indirect(val.Field(i))) { + return true + } + } + return false +} + +// clones the object graph of val. all non-R/O properties are copied to the clone +func recursiveCloneWithoutReadOnlyFields(val reflect.Value) interface{} { + clone := reflect.New(val.Type()) + t := val.Type() + // iterate over the fields, looking for the "azure" tag. + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + aztag := field.Tag.Get("azure") + if azureTagIsReadOnly(aztag) { + // omit from payload + } else if reflect.Indirect(val.Field(i)).Kind() == reflect.Struct { + // recursive case + v := recursiveCloneWithoutReadOnlyFields(reflect.Indirect(val.Field(i))) + if t.Field(i).Anonymous { + // NOTE: this does not handle the case of embedded fields of unexported struct types. + // this should be ok as we don't generate any code like this at present + reflect.Indirect(clone).Field(i).Set(reflect.Indirect(reflect.ValueOf(v))) + } else { + reflect.Indirect(clone).Field(i).Set(reflect.ValueOf(v)) + } + } else { + // no azure RO tag, non-recursive case, include in payload + reflect.Indirect(clone).Field(i).Set(val.Field(i)) + } + } + return clone.Interface() +} + +// returns true if the "azure" tag contains the option "ro" +func azureTagIsReadOnly(tag string) bool { + if tag == "" { + return false + } + parts := strings.Split(tag, ",") + for _, part := range parts { + if part == "ro" { + return true + } + } + return false +} diff --git a/sdk/azcore/request_test.go b/sdk/azcore/runtime/request_test.go similarity index 82% rename from sdk/azcore/request_test.go rename to sdk/azcore/runtime/request_test.go index 98c80e7c795e..aa1bb59a14a9 100644 --- a/sdk/azcore/request_test.go +++ b/sdk/azcore/runtime/request_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -20,6 +20,8 @@ import ( "strings" "testing" "unsafe" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) type testJSON struct { @@ -37,17 +39,17 @@ func TestRequestMarshalXML(t *testing.T) { if err != nil { t.Fatal(err) } - err = req.MarshalAsXML(testXML{SomeInt: 1, SomeString: "s"}) + err = MarshalAsXML(req, testXML{SomeInt: 1, SomeString: "s"}) if err != nil { t.Fatalf("marshal failure: %v", err) } - if ct := req.Header.Get(headerContentType); ct != contentTypeAppXML { - t.Fatalf("unexpected content type, got %s wanted %s", ct, contentTypeAppXML) + if ct := req.Raw().Header.Get(shared.HeaderContentType); ct != shared.ContentTypeAppXML { + t.Fatalf("unexpected content type, got %s wanted %s", ct, shared.ContentTypeAppXML) } - if req.Body == nil { + if req.Raw().Body == nil { t.Fatal("unexpected nil request body") } - if req.ContentLength == 0 { + if req.Raw().ContentLength == 0 { t.Fatal("unexpected zero content length") } } @@ -71,17 +73,17 @@ func TestRequestMarshalJSON(t *testing.T) { if err != nil { t.Fatal(err) } - err = req.MarshalAsJSON(testJSON{SomeInt: 1, SomeString: "s"}) + err = MarshalAsJSON(req, testJSON{SomeInt: 1, SomeString: "s"}) if err != nil { t.Fatalf("marshal failure: %v", err) } - if ct := req.Header.Get(headerContentType); ct != contentTypeAppJSON { - t.Fatalf("unexpected content type, got %s wanted %s", ct, contentTypeAppJSON) + if ct := req.Raw().Header.Get(shared.HeaderContentType); ct != shared.ContentTypeAppJSON { + t.Fatalf("unexpected content type, got %s wanted %s", ct, shared.ContentTypeAppJSON) } - if req.Body == nil { + if req.Raw().Body == nil { t.Fatal("unexpected nil request body") } - if req.ContentLength == 0 { + if req.Raw().ContentLength == 0 { t.Fatal("unexpected zero content length") } } @@ -92,20 +94,20 @@ func TestRequestMarshalAsByteArrayURLFormat(t *testing.T) { t.Fatal(err) } const payload = "a string that gets encoded with base64url" - err = req.MarshalAsByteArray([]byte(payload), Base64URLFormat) + err = MarshalAsByteArray(req, []byte(payload), Base64URLFormat) if err != nil { t.Fatalf("marshal failure: %v", err) } - if ct := req.Header.Get(headerContentType); ct != contentTypeAppJSON { - t.Fatalf("unexpected content type, got %s wanted %s", ct, contentTypeAppJSON) + if ct := req.Raw().Header.Get(shared.HeaderContentType); ct != shared.ContentTypeAppJSON { + t.Fatalf("unexpected content type, got %s wanted %s", ct, shared.ContentTypeAppJSON) } - if req.Body == nil { + if req.Raw().Body == nil { t.Fatal("unexpected nil request body") } - if req.ContentLength == 0 { + if req.Raw().ContentLength == 0 { t.Fatal("unexpected zero content length") } - b, err := ioutil.ReadAll(req.Body) + b, err := ioutil.ReadAll(req.Raw().Body) if err != nil { t.Fatal(err) } @@ -120,20 +122,20 @@ func TestRequestMarshalAsByteArrayStdFormat(t *testing.T) { t.Fatal(err) } const payload = "a string that gets encoded with base64url" - err = req.MarshalAsByteArray([]byte(payload), Base64StdFormat) + err = MarshalAsByteArray(req, []byte(payload), Base64StdFormat) if err != nil { t.Fatalf("marshal failure: %v", err) } - if ct := req.Header.Get(headerContentType); ct != contentTypeAppJSON { - t.Fatalf("unexpected content type, got %s wanted %s", ct, contentTypeAppJSON) + if ct := req.Raw().Header.Get(shared.HeaderContentType); ct != shared.ContentTypeAppJSON { + t.Fatalf("unexpected content type, got %s wanted %s", ct, shared.ContentTypeAppJSON) } - if req.Body == nil { + if req.Raw().Body == nil { t.Fatal("unexpected nil request body") } - if req.ContentLength == 0 { + if req.Raw().ContentLength == 0 { t.Fatal("unexpected zero content length") } - b, err := ioutil.ReadAll(req.Body) + b, err := ioutil.ReadAll(req.Raw().Body) if err != nil { t.Fatal(err) } @@ -337,11 +339,11 @@ func TestCloneWithoutReadOnlyFieldsEndToEnd(t *testing.T) { ID: &id, Name: &name, } - err = req.MarshalAsJSON(nro) + err = MarshalAsJSON(req, nro) if err != nil { t.Fatal(err) } - b, err := ioutil.ReadAll(req.Body) + b, err := ioutil.ReadAll(req.Raw().Body) if err != nil { t.Fatal(err) } @@ -476,36 +478,12 @@ func TestRequestSetBodyContentLengthHeader(t *testing.T) { for i := 0; i < buffLen; i++ { buff[i] = 1 } - err = req.SetBody(NopCloser(bytes.NewReader(buff)), "application/octet-stream") + err = req.SetBody(shared.NopCloser(bytes.NewReader(buff)), "application/octet-stream") if err != nil { t.Fatal(err) } - if req.Header.Get(headerContentLength) != strconv.FormatInt(buffLen, 10) { - t.Fatalf("expected content-length %d, got %s", buffLen, req.Header.Get(headerContentLength)) - } -} - -func TestNewRequestFail(t *testing.T) { - req, err := NewRequest(context.Background(), http.MethodOptions, "://test.contoso.com/") - if err == nil { - t.Fatal("unexpected nil error") - } - if req != nil { - t.Fatal("unexpected request") - } - req, err = NewRequest(context.Background(), http.MethodPatch, "/missing/the/host") - if err == nil { - t.Fatal("unexpected nil error") - } - if req != nil { - t.Fatal("unexpected request") - } - req, err = NewRequest(context.Background(), http.MethodPatch, "mailto://nobody.contoso.com") - if err == nil { - t.Fatal("unexpected nil error") - } - if req != nil { - t.Fatal("unexpected request") + if req.Raw().Header.Get(shared.HeaderContentLength) != strconv.FormatInt(buffLen, 10) { + t.Fatalf("expected content-length %d, got %s", buffLen, req.Raw().Header.Get(shared.HeaderContentLength)) } } @@ -529,7 +507,7 @@ func TestRequestValidFail(t *testing.T) { if err != nil { t.Fatal(err) } - req.Header.Add("inval d", "header") + req.Raw().Header.Add("inval d", "header") p := NewPipeline(nil) resp, err := p.Do(req) if err == nil { @@ -538,9 +516,9 @@ func TestRequestValidFail(t *testing.T) { if resp != nil { t.Fatal("unexpected response") } - req.Header = http.Header{} + req.Raw().Header = http.Header{} // the string "null\0" - req.Header.Add("invalid", string([]byte{0x6e, 0x75, 0x6c, 0x6c, 0x0})) + req.Raw().Header.Add("invalid", string([]byte{0x6e, 0x75, 0x6c, 0x6c, 0x0})) resp, err = p.Do(req) if err == nil { t.Fatal("unexpected nil error") @@ -555,22 +533,22 @@ func TestSetMultipartFormData(t *testing.T) { if err != nil { t.Fatal(err) } - err = req.SetMultipartFormData(map[string]interface{}{ + err = SetMultipartFormData(req, map[string]interface{}{ "string": "value", "int": 1, - "data": NopCloser(strings.NewReader("some data")), + "data": shared.NopCloser(strings.NewReader("some data")), }) if err != nil { t.Fatal(err) } - mt, params, err := mime.ParseMediaType(req.Header.Get(headerContentType)) + mt, params, err := mime.ParseMediaType(req.Raw().Header.Get(shared.HeaderContentType)) if err != nil { t.Fatal(err) } if mt != "multipart/form-data" { t.Fatalf("unexpected media type %s", mt) } - reader := multipart.NewReader(req.Body, params["boundary"]) + reader := multipart.NewReader(req.Raw().Body, params["boundary"]) for { part, err := reader.NextPart() if err == io.EOF { diff --git a/sdk/azcore/response.go b/sdk/azcore/runtime/response.go similarity index 76% rename from sdk/azcore/response.go rename to sdk/azcore/runtime/response.go index 59cb195c417f..c0a990e8aa19 100644 --- a/sdk/azcore/response.go +++ b/sdk/azcore/runtime/response.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "bytes" @@ -16,9 +16,10 @@ import ( "io/ioutil" "net/http" "sort" - "strconv" "strings" - "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) // Payload reads and returns the response body or an error. @@ -39,15 +40,7 @@ func Payload(resp *http.Response) ([]byte, error) { // HasStatusCode returns true if the Response's status code is one of the specified values. func HasStatusCode(resp *http.Response, statusCodes ...int) bool { - if resp == nil { - return false - } - for _, sc := range statusCodes { - if resp.StatusCode == sc { - return true - } - } - return false + return shared.HasStatusCode(resp, statusCodes...) } // UnmarshalAsByteArray will base-64 decode the received payload and place the result into the value pointed to by v. @@ -123,47 +116,6 @@ func removeBOM(resp *http.Response) error { return nil } -// writes to a buffer, used for logging purposes -func writeBody(resp *http.Response, b *bytes.Buffer) error { - ct := resp.Header.Get(headerContentType) - if ct == "" { - fmt.Fprint(b, " Response contained no body\n") - return nil - } else if !shouldLogBody(b, ct) { - return nil - } - body, err := Payload(resp) - if err != nil { - fmt.Fprintf(b, " Failed to read response body: %s\n", err.Error()) - return err - } - if len(body) > 0 { - logBody(b, body) - } else { - fmt.Fprint(b, " Response contained no body\n") - } - return nil -} - -// RetryAfter returns non-zero if the response contains a Retry-After header value. -func RetryAfter(resp *http.Response) time.Duration { - if resp == nil { - return 0 - } - ra := resp.Header.Get(headerRetryAfter) - if ra == "" { - return 0 - } - // retry-after values are expressed in either number of - // seconds or an HTTP-date indicating when to try again - if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 { - return time.Duration(retryAfter) * time.Second - } else if t, err := time.Parse(time.RFC1123, ra); err == nil { - return time.Until(t) - } - return 0 -} - // DecodeByteArray will base-64 decode the provided string into v. func DecodeByteArray(s string, v *[]byte, format Base64Encoding) error { if len(s) == 0 { @@ -197,10 +149,10 @@ func DecodeByteArray(s string, v *[]byte, format Base64Encoding) error { // writeRequestWithResponse appends a formatted HTTP request into a Buffer. If request and/or err are // not nil, then these are also written into the Buffer. -func writeRequestWithResponse(b *bytes.Buffer, request *Request, resp *http.Response, err error) { +func writeRequestWithResponse(b *bytes.Buffer, req *policy.Request, resp *http.Response, err error) { // Write the request into the buffer. - fmt.Fprint(b, " "+request.Method+" "+request.URL.String()+"\n") - writeHeader(b, request.Header) + fmt.Fprint(b, " "+req.Raw().Method+" "+req.Raw().URL.String()+"\n") + writeHeader(b, req.Raw().Header) if resp != nil { fmt.Fprintln(b, " --------------------------------------------------------------------------------") fmt.Fprint(b, " RESPONSE Status: "+resp.Status+"\n") diff --git a/sdk/azcore/response_test.go b/sdk/azcore/runtime/response_test.go similarity index 88% rename from sdk/azcore/response_test.go rename to sdk/azcore/runtime/response_test.go index 9e325fa048bc..cfd867a74379 100644 --- a/sdk/azcore/response_test.go +++ b/sdk/azcore/runtime/response_test.go @@ -4,13 +4,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package runtime import ( "context" "net/http" "testing" - "time" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -152,33 +151,6 @@ func TestResponseUnmarshalXMLNoBody(t *testing.T) { } } -func TestRetryAfter(t *testing.T) { - resp := &http.Response{ - Header: http.Header{}, - } - if d := RetryAfter(resp); d > 0 { - t.Fatalf("unexpected retry-after value %d", d) - } - resp.Header.Set(headerRetryAfter, "300") - d := RetryAfter(resp) - if d <= 0 { - t.Fatal("expected retry-after value from seconds") - } - if d != 300*time.Second { - t.Fatalf("expected 300 seconds, got %d", d/time.Second) - } - atDate := time.Now().Add(600 * time.Second) - resp.Header.Set(headerRetryAfter, atDate.Format(time.RFC1123)) - d = RetryAfter(resp) - if d <= 0 { - t.Fatal("expected retry-after value from date") - } - // d will not be exactly 600 seconds but it will be close - if s := d / time.Second; s < 598 || s > 602 { - t.Fatalf("expected ~600 seconds, got %d", s) - } -} - func TestResponseUnmarshalAsByteArrayURLFormat(t *testing.T) { srv, close := mock.NewServer() defer close() diff --git a/sdk/azcore/runtime/transport_default_http_client.go b/sdk/azcore/runtime/transport_default_http_client.go new file mode 100644 index 000000000000..4352f916c455 --- /dev/null +++ b/sdk/azcore/runtime/transport_default_http_client.go @@ -0,0 +1,35 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "crypto/tls" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +var defaultHTTPClient *http.Client + +func init() { + defaultTransport := http.DefaultTransport.(*http.Transport).Clone() + defaultTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 + defaultHTTPClient = &http.Client{ + Transport: defaultTransport, + } +} + +// AuthenticationOptions contains various options used to create a credential policy. +type AuthenticationOptions struct { + // TokenRequest is a TokenRequestOptions that includes a scopes field which contains + // the list of OAuth2 authentication scopes used when requesting a token. + // This field is ignored for other forms of authentication (e.g. shared key). + TokenRequest policy.TokenRequestOptions + // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate + // in cross-tenant applications. + AuxiliaryTenants []string +} diff --git a/sdk/azcore/progress.go b/sdk/azcore/streaming/progress.go similarity index 78% rename from sdk/azcore/progress.go rename to sdk/azcore/streaming/progress.go index cfdd2bf1d902..ca0b05c80812 100644 --- a/sdk/azcore/progress.go +++ b/sdk/azcore/streaming/progress.go @@ -4,21 +4,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package streaming import ( "io" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) type progress struct { rc io.ReadCloser - rsc ReadSeekCloser + rsc io.ReadSeekCloser pr func(bytesTransferred int64) offset int64 } +// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. +func NopCloser(rs io.ReadSeeker) io.ReadSeekCloser { + return shared.NopCloser(rs) +} + // NewRequestProgress adds progress reporting to an HTTP request's body stream. -func NewRequestProgress(body ReadSeekCloser, pr func(bytesTransferred int64)) ReadSeekCloser { +func NewRequestProgress(body io.ReadSeekCloser, pr func(bytesTransferred int64)) io.ReadSeekCloser { return &progress{ rc: body, rsc: body, diff --git a/sdk/azcore/progress_test.go b/sdk/azcore/streaming/progress_test.go similarity index 89% rename from sdk/azcore/progress_test.go rename to sdk/azcore/streaming/progress_test.go index bcf6e7abde14..cf68bdf5eb19 100644 --- a/sdk/azcore/progress_test.go +++ b/sdk/azcore/streaming/progress_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package azcore +package streaming import ( "bytes" @@ -15,6 +15,7 @@ import ( "reflect" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -28,8 +29,8 @@ func TestProgressReporting(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithBody(content)) - pl := NewPipeline(srv) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + pl := runtime.NewPipeline(srv) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -77,8 +78,8 @@ func TestProgressReportingSeek(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithBody(content)) - pl := NewPipeline(srv) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + pl := runtime.NewPipeline(srv) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/sdk/azcore/to/to.go b/sdk/azcore/to/to.go new file mode 100644 index 000000000000..01bb033ef03c --- /dev/null +++ b/sdk/azcore/to/to.go @@ -0,0 +1,107 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package to + +import "time" + +// BoolPtr returns a pointer to the provided bool. +func BoolPtr(b bool) *bool { + return &b +} + +// Float32Ptr returns a pointer to the provided float32. +func Float32Ptr(i float32) *float32 { + return &i +} + +// Float64Ptr returns a pointer to the provided float64. +func Float64Ptr(i float64) *float64 { + return &i +} + +// Int32Ptr returns a pointer to the provided int32. +func Int32Ptr(i int32) *int32 { + return &i +} + +// Int64Ptr returns a pointer to the provided int64. +func Int64Ptr(i int64) *int64 { + return &i +} + +// StringPtr returns a pointer to the provided string. +func StringPtr(s string) *string { + return &s +} + +// TimePtr returns a pointer to the provided time.Time. +func TimePtr(t time.Time) *time.Time { + return &t +} + +// Int32PtrArray returns an array of *int32 from the specified values. +func Int32PtrArray(vals ...int32) []*int32 { + arr := make([]*int32, len(vals)) + for i := range vals { + arr[i] = Int32Ptr(vals[i]) + } + return arr +} + +// Int64PtrArray returns an array of *int64 from the specified values. +func Int64PtrArray(vals ...int64) []*int64 { + arr := make([]*int64, len(vals)) + for i := range vals { + arr[i] = Int64Ptr(vals[i]) + } + return arr +} + +// Float32PtrArray returns an array of *float32 from the specified values. +func Float32PtrArray(vals ...float32) []*float32 { + arr := make([]*float32, len(vals)) + for i := range vals { + arr[i] = Float32Ptr(vals[i]) + } + return arr +} + +// Float64PtrArray returns an array of *float64 from the specified values. +func Float64PtrArray(vals ...float64) []*float64 { + arr := make([]*float64, len(vals)) + for i := range vals { + arr[i] = Float64Ptr(vals[i]) + } + return arr +} + +// BoolPtrArray returns an array of *bool from the specified values. +func BoolPtrArray(vals ...bool) []*bool { + arr := make([]*bool, len(vals)) + for i := range vals { + arr[i] = BoolPtr(vals[i]) + } + return arr +} + +// StringPtrArray returns an array of *string from the specified values. +func StringPtrArray(vals ...string) []*string { + arr := make([]*string, len(vals)) + for i := range vals { + arr[i] = StringPtr(vals[i]) + } + return arr +} + +// TimePtrArray returns an array of *time.Time from the specified values. +func TimePtrArray(vals ...time.Time) []*time.Time { + arr := make([]*time.Time, len(vals)) + for i := range vals { + arr[i] = TimePtr(vals[i]) + } + return arr +} diff --git a/sdk/azcore/to/to_test.go b/sdk/azcore/to/to_test.go new file mode 100644 index 000000000000..ef9374b0ceda --- /dev/null +++ b/sdk/azcore/to/to_test.go @@ -0,0 +1,192 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package to + +import ( + "fmt" + "reflect" + "strconv" + "testing" + "time" +) + +func TestBoolPtr(t *testing.T) { + b := true + pb := BoolPtr(b) + if pb == nil { + t.Fatal("unexpected nil conversion") + } + if *pb != b { + t.Fatalf("got %v, want %v", *pb, b) + } +} + +func TestFloat32Ptr(t *testing.T) { + f32 := float32(3.1415926) + pf32 := Float32Ptr(f32) + if pf32 == nil { + t.Fatal("unexpected nil conversion") + } + if *pf32 != f32 { + t.Fatalf("got %v, want %v", *pf32, f32) + } +} + +func TestFloat64Ptr(t *testing.T) { + f64 := float64(2.71828182845904) + pf64 := Float64Ptr(f64) + if pf64 == nil { + t.Fatal("unexpected nil conversion") + } + if *pf64 != f64 { + t.Fatalf("got %v, want %v", *pf64, f64) + } +} + +func TestInt32Ptr(t *testing.T) { + i32 := int32(123456789) + pi32 := Int32Ptr(i32) + if pi32 == nil { + t.Fatal("unexpected nil conversion") + } + if *pi32 != i32 { + t.Fatalf("got %v, want %v", *pi32, i32) + } +} + +func TestInt64Ptr(t *testing.T) { + i64 := int64(9876543210) + pi64 := Int64Ptr(i64) + if pi64 == nil { + t.Fatal("unexpected nil conversion") + } + if *pi64 != i64 { + t.Fatalf("got %v, want %v", *pi64, i64) + } +} + +func TestStringPtr(t *testing.T) { + s := "the string" + ps := StringPtr(s) + if ps == nil { + t.Fatal("unexpected nil conversion") + } + if *ps != s { + t.Fatalf("got %v, want %v", *ps, s) + } +} + +func TestTimePtr(t *testing.T) { + tt := time.Now() + pt := TimePtr(tt) + if pt == nil { + t.Fatal("unexpected nil conversion") + } + if *pt != tt { + t.Fatalf("got %v, want %v", *pt, tt) + } +} + +func TestInt32PtrArray(t *testing.T) { + arr := Int32PtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = Int32PtrArray(1, 2, 3, 4, 5) + for i, v := range arr { + if *v != int32(i+1) { + t.Fatal("values don't match") + } + } +} + +func TestInt64PtrArray(t *testing.T) { + arr := Int64PtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = Int64PtrArray(1, 2, 3, 4, 5) + for i, v := range arr { + if *v != int64(i+1) { + t.Fatal("values don't match") + } + } +} + +func TestFloat32PtrArray(t *testing.T) { + arr := Float32PtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = Float32PtrArray(1.1, 2.2, 3.3, 4.4, 5.5) + for i, v := range arr { + f, err := strconv.ParseFloat(fmt.Sprintf("%d.%d", i+1, i+1), 32) + if err != nil { + t.Fatal(err) + } + if *v != float32(f) { + t.Fatal("values don't match") + } + } +} + +func TestFloat64PtrArray(t *testing.T) { + arr := Float64PtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = Float64PtrArray(1.1, 2.2, 3.3, 4.4, 5.5) + for i, v := range arr { + f, err := strconv.ParseFloat(fmt.Sprintf("%d.%d", i+1, i+1), 64) + if err != nil { + t.Fatal(err) + } + if *v != f { + t.Fatal("values don't match") + } + } +} + +func TestBoolPtrArray(t *testing.T) { + arr := BoolPtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = BoolPtrArray(true, false, true) + curr := true + for _, v := range arr { + if *v != curr { + t.Fatal("values don'p match") + } + curr = !curr + } +} + +func TestStringPtrArray(t *testing.T) { + arr := StringPtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + arr = StringPtrArray("one", "", "three") + if !reflect.DeepEqual(arr, []*string{StringPtr("one"), StringPtr(""), StringPtr("three")}) { + t.Fatal("values don't match") + } +} + +func TestTimePtrArray(t *testing.T) { + arr := TimePtrArray() + if len(arr) != 0 { + t.Fatal("expected zero length") + } + t1 := time.Now() + t2 := time.Time{} + t3 := t1.Add(24 * time.Hour) + arr = TimePtrArray(t1, t2, t3) + if !reflect.DeepEqual(arr, []*time.Time{&t1, &t2, &t3}) { + t.Fatal("values don't match") + } +} diff --git a/sdk/azcore/transport_default_http_client.go b/sdk/azcore/transport_default_http_client.go deleted file mode 100644 index 02a36bcbe741..000000000000 --- a/sdk/azcore/transport_default_http_client.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -import ( - "crypto/tls" - "net/http" -) - -var defaultHTTPClient *http.Client - -func init() { - defaultTransport := http.DefaultTransport.(*http.Transport).Clone() - defaultTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 - defaultHTTPClient = &http.Client{ - Transport: defaultTransport, - } -} diff --git a/sdk/azcore/version.go b/sdk/azcore/version.go deleted file mode 100644 index 2220616518bc..000000000000 --- a/sdk/azcore/version.go +++ /dev/null @@ -1,15 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azcore - -const ( - // UserAgent is the string to be used in the user agent string when making requests. - UserAgent = "azcore/" + Version - - // Version is the semantic version (see http://semver.org) of this module. - Version = "v0.18.0" -)