diff --git a/doc.go b/doc.go index 57fb06da81..853036c026 100644 --- a/doc.go +++ b/doc.go @@ -42,6 +42,8 @@ Call the Elasticsearch APIs by invoking the corresponding methods on the client: log.Println(res) -See the github.com/elastic/go-elasticsearch/esapi package for more information and examples. +See the github.com/elastic/go-elasticsearch/esapi package for more information about using the API. + +See the github.com/elastic/go-elasticsearch/estransport package for more information about configuring the transport. */ package elasticsearch diff --git a/elasticsearch.go b/elasticsearch.go index 334918e59c..c38d8a321e 100644 --- a/elasticsearch.go +++ b/elasticsearch.go @@ -12,6 +12,7 @@ import ( "net/url" "os" "strings" + "time" "github.com/elastic/go-elasticsearch/v7/esapi" "github.com/elastic/go-elasticsearch/v7/estransport" @@ -36,6 +37,13 @@ type Config struct { CloudID string // Endpoint for the Elastic Service (https://elastic.co/cloud). APIKey string // Base64-encoded token for authorization; if set, overrides username and password. + RetryOnStatus []int // List of status codes for retry. Default: 502, 503, 504. + DisableRetry bool // Default: false. + EnableRetryOnTimeout bool // Default: false. + MaxRetries int // Default: 3. + + RetryBackoff func(attempt int) time.Duration // Optional backoff duration. Default: nil. + Transport http.RoundTripper // The HTTP transport object. Logger estransport.Logger // The logger object. } @@ -116,6 +124,12 @@ func NewClient(cfg Config) (*Client, error) { Password: cfg.Password, APIKey: cfg.APIKey, + RetryOnStatus: cfg.RetryOnStatus, + DisableRetry: cfg.DisableRetry, + EnableRetryOnTimeout: cfg.EnableRetryOnTimeout, + MaxRetries: cfg.MaxRetries, + RetryBackoff: cfg.RetryBackoff, + Transport: cfg.Transport, Logger: cfg.Logger, }) diff --git a/esapi/doc.go b/esapi/doc.go index 12d0815e43..f522175f99 100644 --- a/esapi/doc.go +++ b/esapi/doc.go @@ -85,10 +85,11 @@ about the API endpoints and parameters. The Go API is generated from the Elasticsearch JSON specification at https://github.com/elastic/elasticsearch/tree/master/rest-api-spec/src/main/resources/rest-api-spec/api by the internal package available at -https://github.com/elastic/go-elasticsearch/tree/master/internal/cmd/generate/commands. +https://github.com/elastic/go-elasticsearch/tree/master/internal/cmd/generate/commands/gensource. The API is tested by integration tests common to all Elasticsearch official clients, generated from the -source at https://github.com/elastic/elasticsearch/tree/master/rest-api-spec/src/main/resources/rest-api-spec/test. The generator is provided by the internal package internal/cmd/generate. +source at https://github.com/elastic/elasticsearch/tree/master/rest-api-spec/src/main/resources/rest-api-spec/test. +The generator is provided by the internal package available at internal/cmd/generate/commands/gentests. */ package esapi diff --git a/esapi/esapi.request.go b/esapi/esapi.request.go index 250c350332..d967fc61f5 100644 --- a/esapi/esapi.request.go +++ b/esapi/esapi.request.go @@ -5,13 +5,9 @@ package esapi import ( - "bytes" "context" "io" - "io/ioutil" "net/http" - "net/url" - "strings" ) const ( @@ -31,30 +27,5 @@ type Request interface { // newRequest creates an HTTP request. // func newRequest(method, path string, body io.Reader) (*http.Request, error) { - r := http.Request{ - Method: method, - URL: &url.URL{Path: path}, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - } - - if body != nil { - switch b := body.(type) { - case *bytes.Buffer: - r.Body = ioutil.NopCloser(body) - r.ContentLength = int64(b.Len()) - case *bytes.Reader: - r.Body = ioutil.NopCloser(body) - r.ContentLength = int64(b.Len()) - case *strings.Reader: - r.Body = ioutil.NopCloser(body) - r.ContentLength = int64(b.Len()) - default: - r.Body = ioutil.NopCloser(body) - } - } - - return &r, nil + return http.NewRequest(method, path, body) } diff --git a/estransport/doc.go b/estransport/doc.go index aa3eba3f95..e65e34284f 100644 --- a/estransport/doc.go +++ b/estransport/doc.go @@ -4,12 +4,23 @@ Package estransport provides the transport layer for the Elasticsearch client. It is automatically included in the client provided by the github.com/elastic/go-elasticsearch package and is not intended for direct use: to configure the client, use the elasticsearch.Config struct. -The default HTTP transport of the client is http.Transport. +The default HTTP transport of the client is http.Transport; use the Transport option to customize it; +see the _examples/customization.go file in this repository for information. The package defines the "Selector" interface for getting a URL from the list. At the moment, the implementation is rather minimal: the client takes a slice of url.URL pointers, and round-robins across them when performing the request. +The package will automatically retry requests on network-related errors, and on specific +response status codes (by default 502, 503, 504). Use the RetryOnStatus option to customize the list. +The transport will not retry a timeout network error, unless enabled by setting EnableRetryOnTimeout to true. + +Use the MaxRetries option to configure the number of retries, and set DisableRetry to true +to disable the retry behaviour altogether. + +By default, the retry will be performed without any delay; to configure a backoff interval, +implement the RetryBackoff option function; see an example in the package unit tests for information. + The package defines the "Logger" interface for logging information about request and response. It comes with several bundled loggers for logging in text and JSON. diff --git a/estransport/estransport.go b/estransport/estransport.go index e46d13fce2..5500ea8a6c 100644 --- a/estransport/estransport.go +++ b/estransport/estransport.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/url" "regexp" @@ -26,6 +27,9 @@ const Version = version.Client var ( userAgent string reGoVersion = regexp.MustCompile(`go(\d+\.\d+\..+)`) + + defaultMaxRetries = 3 + defaultRetryOnStatus = [...]int{502, 503, 504} ) func init() { @@ -46,6 +50,12 @@ type Config struct { Password string APIKey string + RetryOnStatus []int + DisableRetry bool + EnableRetryOnTimeout bool + MaxRetries int + RetryBackoff func(attempt int) time.Duration + Transport http.RoundTripper Logger Logger } @@ -58,6 +68,12 @@ type Client struct { password string apikey string + retryOnStatus []int + disableRetry bool + enableRetryOnTimeout bool + maxRetries int + retryBackoff func(attempt int) time.Duration + transport http.RoundTripper selector Selector logger Logger @@ -72,12 +88,26 @@ func New(cfg Config) *Client { cfg.Transport = http.DefaultTransport } + if len(cfg.RetryOnStatus) == 0 { + cfg.RetryOnStatus = defaultRetryOnStatus[:] + } + + if cfg.MaxRetries == 0 { + cfg.MaxRetries = defaultMaxRetries + } + return &Client{ urls: cfg.URLs, username: cfg.Username, password: cfg.Password, apikey: cfg.APIKey, + retryOnStatus: cfg.RetryOnStatus, + disableRetry: cfg.DisableRetry, + enableRetryOnTimeout: cfg.EnableRetryOnTimeout, + maxRetries: cfg.MaxRetries, + retryBackoff: cfg.RetryBackoff, + transport: cfg.Transport, selector: NewRoundRobinSelector(cfg.URLs...), logger: cfg.Logger, @@ -88,41 +118,86 @@ func New(cfg Config) *Client { // func (c *Client) Perform(req *http.Request) (*http.Response, error) { var ( - dupReqBody io.Reader - ) + res *http.Response + err error - // Get URL from the Selector - // - u, err := c.getURL() - if err != nil { - // TODO(karmi): Log error - return nil, fmt.Errorf("cannot get URL: %s", err) - } + dupReqBodyForLog io.ReadCloser + ) // Update request // - c.setURL(u, req) - c.setUserAgent(req) - c.setAuthorization(u, req) + c.setReqUserAgent(req) + + for i := 1; i <= c.maxRetries; i++ { + var ( + nodeURL *url.URL + shouldRetry bool + ) + + // Get URL from the Selector + // + nodeURL, err = c.getURL() + if err != nil { + // TODO(karmi): Log error + return nil, fmt.Errorf("cannot get URL: %s", err) + } - // Duplicate request body for logger - // - if c.logger != nil && c.logger.RequestBodyEnabled() { - if req.Body != nil && req.Body != http.NoBody { - dupReqBody, req.Body, _ = duplicateBody(req.Body) + // Update request + // + c.setReqURL(nodeURL, req) + c.setReqAuth(nodeURL, req) + + // Duplicate request body for logger + // + if c.logger != nil && c.logger.RequestBodyEnabled() { + if req.Body != nil && req.Body != http.NoBody { + dupReqBodyForLog, req.Body, _ = duplicateBody(req.Body) + } } - } - // Set up time measures and execute the request - // - start := time.Now().UTC() - res, err := c.transport.RoundTrip(req) - dur := time.Since(start) + // Set up time measures and execute the request + // + start := time.Now().UTC() + res, err = c.transport.RoundTrip(req) + dur := time.Since(start) - // Log request and response - // - if c.logger != nil { - c.logRoundTrip(req, res, dupReqBody, err, start, dur) + // Log request and response + // + if c.logger != nil { + c.logRoundTrip(req, res, dupReqBodyForLog, err, start, dur) + } + + // Retry only on network errors, but don't retry on timeout errors, unless configured + // + if err != nil { + if err, ok := err.(net.Error); ok { + if (!err.Timeout() || c.enableRetryOnTimeout) && !c.disableRetry { + shouldRetry = true + } + } + } + + // Retry on configured response statuses + // + if res != nil && !c.disableRetry { + for _, code := range c.retryOnStatus { + if res.StatusCode == code { + shouldRetry = true + } + } + } + + // Break if retry should not be performed + // + if !shouldRetry { + break + } + + // Delay the retry if a backoff function is configured + // + if c.retryBackoff != nil { + time.Sleep(c.retryBackoff(i)) + } } // TODO(karmi): Wrap error @@ -139,7 +214,7 @@ func (c *Client) getURL() (*url.URL, error) { return c.selector.Select() } -func (c *Client) setURL(u *url.URL, req *http.Request) *http.Request { +func (c *Client) setReqURL(u *url.URL, req *http.Request) *http.Request { req.URL.Scheme = u.Scheme req.URL.Host = u.Host @@ -154,7 +229,7 @@ func (c *Client) setURL(u *url.URL, req *http.Request) *http.Request { return req } -func (c *Client) setAuthorization(u *url.URL, req *http.Request) *http.Request { +func (c *Client) setReqAuth(u *url.URL, req *http.Request) *http.Request { if _, ok := req.Header["Authorization"]; !ok { if u.User != nil { password, _ := u.User.Password() @@ -180,7 +255,7 @@ func (c *Client) setAuthorization(u *url.URL, req *http.Request) *http.Request { return req } -func (c *Client) setUserAgent(req *http.Request) *http.Request { +func (c *Client) setReqUserAgent(req *http.Request) *http.Request { req.Header.Set("User-Agent", userAgent) return req } diff --git a/estransport/estransport_internal_test.go b/estransport/estransport_internal_test.go index 4a2b812a27..eb6f28eb1e 100644 --- a/estransport/estransport_internal_test.go +++ b/estransport/estransport_internal_test.go @@ -10,8 +10,10 @@ import ( "fmt" "net/http" "net/url" + "reflect" "strings" "testing" + "time" ) var ( @@ -26,6 +28,11 @@ func (t *mockTransp) RoundTrip(req *http.Request) (*http.Response, error) { return t.RoundTripFunc(req) } +type mockNetError struct{ error } + +func (e *mockNetError) Timeout() bool { return false } +func (e *mockNetError) Temporary() bool { return false } + func TestTransport(t *testing.T) { t.Run("Interface", func(t *testing.T) { var _ Interface = New(Config{}) @@ -61,6 +68,53 @@ func TestTransport(t *testing.T) { }) } +func TestTransportConfig(t *testing.T) { + t.Run("Defaults", func(t *testing.T) { + tp := New(Config{}) + + if !reflect.DeepEqual(tp.retryOnStatus, []int{502, 503, 504}) { + t.Errorf("Unexpected retryOnStatus: %v", tp.retryOnStatus) + } + + if tp.disableRetry { + t.Errorf("Unexpected disableRetry: %v", tp.disableRetry) + } + + if tp.enableRetryOnTimeout { + t.Errorf("Unexpected enableRetryOnTimeout: %v", tp.enableRetryOnTimeout) + } + + if tp.maxRetries != 3 { + t.Errorf("Unexpected maxRetries: %v", tp.maxRetries) + } + }) + + t.Run("Custom", func(t *testing.T) { + tp := New(Config{ + RetryOnStatus: []int{404, 408}, + DisableRetry: true, + EnableRetryOnTimeout: true, + MaxRetries: 5, + }) + + if !reflect.DeepEqual(tp.retryOnStatus, []int{404, 408}) { + t.Errorf("Unexpected retryOnStatus: %v", tp.retryOnStatus) + } + + if !tp.disableRetry { + t.Errorf("Unexpected disableRetry: %v", tp.disableRetry) + } + + if !tp.enableRetryOnTimeout { + t.Errorf("Unexpected enableRetryOnTimeout: %v", tp.enableRetryOnTimeout) + } + + if tp.maxRetries != 5 { + t.Errorf("Unexpected maxRetries: %v", tp.maxRetries) + } + }) +} + func TestTransportPerform(t *testing.T) { t.Run("Executes", func(t *testing.T) { u, _ := url.Parse("https://foo.com/bar") @@ -87,7 +141,7 @@ func TestTransportPerform(t *testing.T) { tp := New(Config{URLs: []*url.URL{u}}) req, _ := http.NewRequest("GET", "/abc", nil) - tp.setURL(u, req) + tp.setReqURL(u, req) expected := "https://foo.com/bar/abc" @@ -101,7 +155,7 @@ func TestTransportPerform(t *testing.T) { tp := New(Config{URLs: []*url.URL{u}}) req, _ := http.NewRequest("GET", "/", nil) - tp.setAuthorization(u, req) + tp.setReqAuth(u, req) username, password, ok := req.BasicAuth() if !ok { @@ -118,7 +172,7 @@ func TestTransportPerform(t *testing.T) { tp := New(Config{URLs: []*url.URL{u}, Username: "foo", Password: "bar"}) req, _ := http.NewRequest("GET", "/", nil) - tp.setAuthorization(u, req) + tp.setReqAuth(u, req) username, password, ok := req.BasicAuth() if !ok { @@ -135,7 +189,7 @@ func TestTransportPerform(t *testing.T) { tp := New(Config{URLs: []*url.URL{u}, APIKey: "Zm9vYmFy"}) // foobar req, _ := http.NewRequest("GET", "/", nil) - tp.setAuthorization(u, req) + tp.setReqAuth(u, req) value := req.Header.Get("Authorization") if value == "" { @@ -152,7 +206,7 @@ func TestTransportPerform(t *testing.T) { tp := New(Config{URLs: []*url.URL{u}}) req, _ := http.NewRequest("GET", "/abc", nil) - tp.setUserAgent(req) + tp.setReqUserAgent(req) if !strings.HasPrefix(req.UserAgent(), "go-elasticsearch") { t.Errorf("Unexpected user agent: %s", req.UserAgent()) @@ -175,6 +229,233 @@ func TestTransportPerform(t *testing.T) { }) } +func TestTransportPerformRetries(t *testing.T) { + t.Run("Retry request on network error and return the response", func(t *testing.T) { + var ( + i int + numReqs = 2 + ) + + u, _ := url.Parse("http://foo.bar") + tp := New(Config{ + URLs: []*url.URL{u, u, u}, + Transport: &mockTransp{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + i++ + fmt.Printf("Request #%d", i) + if i == numReqs { + fmt.Print(": OK\n") + return &http.Response{Status: "OK"}, nil + } + fmt.Print(": ERR\n") + return nil, &mockNetError{error: fmt.Errorf("Mock network error (%d)", i)} + }, + }}) + + req, _ := http.NewRequest("GET", "/abc", nil) + + res, err := tp.Perform(req) + + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if res.Status != "OK" { + t.Errorf("Unexpected response: %+v", res) + } + + if i != numReqs { + t.Errorf("Unexpected number of requests, want=%d, got=%d", numReqs, i) + } + }) + + t.Run("Retry request on 5xx response and return new response", func(t *testing.T) { + var ( + i int + numReqs = 2 + ) + + u, _ := url.Parse("http://foo.bar") + tp := New(Config{ + URLs: []*url.URL{u, u, u}, + Transport: &mockTransp{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + i++ + fmt.Printf("Request #%d", i) + if i == numReqs { + fmt.Print(": 200\n") + return &http.Response{StatusCode: 200}, nil + } + fmt.Print(": 502\n") + return &http.Response{StatusCode: 502}, nil + }, + }}) + + req, _ := http.NewRequest("GET", "/abc", nil) + + res, err := tp.Perform(req) + + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if res.StatusCode != 200 { + t.Errorf("Unexpected response: %+v", res) + } + + if i != numReqs { + t.Errorf("Unexpected number of requests, want=%d, got=%d", numReqs, i) + } + }) + + t.Run("Retry request and return error when max retries exhausted", func(t *testing.T) { + var ( + i int + numReqs = 3 + ) + + u, _ := url.Parse("http://foo.bar") + tp := New(Config{ + URLs: []*url.URL{u, u, u}, + Transport: &mockTransp{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + i++ + fmt.Printf("Request #%d", i) + fmt.Print(": ERR\n") + return nil, &mockNetError{error: fmt.Errorf("Mock network error (%d)", i)} + }, + }}) + + req, _ := http.NewRequest("GET", "/abc", nil) + + res, err := tp.Perform(req) + + if err == nil { + t.Fatalf("Expected error, got: %v", err) + } + + if res != nil { + t.Errorf("Unexpected response: %+v", res) + } + + if i != numReqs { + t.Errorf("Unexpected number of requests, want=%d, got=%d", numReqs, i) + } + }) + + t.Run("Don't retry request on regular error", func(t *testing.T) { + var i int + + u, _ := url.Parse("http://foo.bar") + tp := New(Config{ + URLs: []*url.URL{u, u, u}, + Transport: &mockTransp{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + i++ + fmt.Printf("Request #%d", i) + fmt.Print(": ERR\n") + return nil, fmt.Errorf("Mock regular error (%d)", i) + }, + }}) + + req, _ := http.NewRequest("GET", "/abc", nil) + + res, err := tp.Perform(req) + + if err == nil { + t.Fatalf("Expected error, got: %v", err) + } + + if res != nil { + t.Errorf("Unexpected response: %+v", res) + } + + if i != 1 { + t.Errorf("Unexpected number of requests, want=%d, got=%d", 1, i) + } + }) + + t.Run("Don't retry request when retries are disabled", func(t *testing.T) { + var i int + + u, _ := url.Parse("http://foo.bar") + tp := New(Config{ + URLs: []*url.URL{u, u, u}, + Transport: &mockTransp{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + i++ + fmt.Printf("Request #%d", i) + fmt.Print(": ERR\n") + return nil, &mockNetError{error: fmt.Errorf("Mock network error (%d)", i)} + }, + }, + DisableRetry: true, + }) + + req, _ := http.NewRequest("GET", "/abc", nil) + tp.Perform(req) + + if i != 1 { + t.Errorf("Unexpected number of requests, want=%d, got=%d", 1, i) + } + }) + + t.Run("Delay the retry with a backoff function", func(t *testing.T) { + var ( + i int + numReqs = 3 + start = time.Now() + expectedDuration = time.Duration(numReqs*100) * time.Millisecond + ) + + u, _ := url.Parse("http://foo.bar") + tp := New(Config{ + URLs: []*url.URL{u, u, u}, + Transport: &mockTransp{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + i++ + fmt.Printf("Request #%d", i) + if i == numReqs { + fmt.Print(": OK\n") + return &http.Response{Status: "OK"}, nil + } + fmt.Print(": ERR\n") + return nil, &mockNetError{error: fmt.Errorf("Mock network error (%d)", i)} + }, + }, + + // A simple incremental backoff function + // + RetryBackoff: func(i int) time.Duration { + d := time.Duration(i) * 100 * time.Millisecond + fmt.Printf("Attempt: %d | Sleeping for %s...\n", i, d) + return d + }, + }) + + req, _ := http.NewRequest("GET", "/abc", nil) + + res, err := tp.Perform(req) + end := time.Since(start) + + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if res.Status != "OK" { + t.Errorf("Unexpected response: %+v", res) + } + + if i != numReqs { + t.Errorf("Unexpected number of requests, want=%d, got=%d", numReqs, i) + } + + if end < expectedDuration { + t.Errorf("Unexpected duration, want=>%s, got=%s", expectedDuration, end) + } + }) +} + func TestTransportSelector(t *testing.T) { t.Run("Nil value", func(t *testing.T) { tp := New(Config{URLs: []*url.URL{nil}}) diff --git a/estransport/logger.go b/estransport/logger.go index 85af2fce55..263c80c5a5 100644 --- a/estransport/logger.go +++ b/estransport/logger.go @@ -74,7 +74,12 @@ func (l *TextLogger) LogRoundTrip(req *http.Request, res *http.Response, err err ) if l.RequestBodyEnabled() && req != nil && req.Body != nil && req.Body != http.NoBody { var buf bytes.Buffer - buf.ReadFrom(req.Body) + if req.GetBody != nil { + b, _ := req.GetBody() + buf.ReadFrom(b) + } else { + buf.ReadFrom(req.Body) + } logBodyAsText(l.Output, &buf, ">") } if l.ResponseBodyEnabled() && res != nil && res.Body != nil && res.Body != http.NoBody { @@ -134,7 +139,12 @@ func (l *ColorLogger) LogRoundTrip(req *http.Request, res *http.Response, err er if l.RequestBodyEnabled() && req != nil && req.Body != nil && req.Body != http.NoBody { var buf bytes.Buffer - buf.ReadFrom(req.Body) + if req.GetBody != nil { + b, _ := req.GetBody() + buf.ReadFrom(b) + } else { + buf.ReadFrom(req.Body) + } fmt.Fprint(l.Output, "\x1b[2m") logBodyAsText(l.Output, &buf, " ยป") fmt.Fprint(l.Output, "\x1b[0m") @@ -211,7 +221,12 @@ func (l *CurlLogger) LogRoundTrip(req *http.Request, res *http.Response, err err if req != nil && req.Body != nil && req.Body != http.NoBody { var buf bytes.Buffer - buf.ReadFrom(req.Body) + if req.GetBody != nil { + b, _ := req.GetBody() + buf.ReadFrom(b) + } else { + buf.ReadFrom(req.Body) + } b.Grow(buf.Len()) b.WriteString(" -d \\\n'") @@ -314,7 +329,12 @@ func (l *JSONLogger) LogRoundTrip(req *http.Request, res *http.Response, err err appendQuote(req.Method) if l.RequestBodyEnabled() && req != nil && req.Body != nil && req.Body != http.NoBody { var buf bytes.Buffer - buf.ReadFrom(req.Body) + if req.GetBody != nil { + b, _ := req.GetBody() + buf.ReadFrom(b) + } else { + buf.ReadFrom(req.Body) + } b.Grow(buf.Len() + 8) b.WriteString(`,"body":`)