diff --git a/httpclient/api_client.go b/httpclient/api_client.go index e2e8a5d3e..d948392f8 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -88,7 +88,10 @@ func NewApiClient(cfg ClientConfig) *ApiClient { config: cfg, rateLimiter: rate.NewLimiter(rateLimit, 1), httpClient: &http.Client{ - Timeout: cfg.HTTPTimeout, + // We deal with request timeouts ourselves such that we do not + // time out during request or response body reads that make + // progress (e.g. on a slower network connection). + Timeout: 0, Transport: transport, }, } @@ -196,6 +199,13 @@ func (c *ApiClient) attempt( if err != nil { return c.failRequest(ctx, "failed in rate limiter", err) } + + pctx := ctx + + // This timeout context enables us to extend the request timeout + // while the request or response body is being read. + // It exists because the net/http package uses a fixed timeout regardless of payload size. + ctx, ticker := newTimeoutContext(pctx, c.config.HTTPTimeout) request, err := http.NewRequestWithContext(ctx, method, requestURL, requestBody.Reader) if err != nil { return c.failRequest(ctx, "failed creating new request", err) @@ -211,16 +221,36 @@ func (c *ApiClient) attempt( if request.Header.Get("Content-Type") == "" && requestBody.ContentType != "" { request.Header.Set("Content-Type", requestBody.ContentType) } + // If there is a request body, wrap it to extend the request timeout while it is being read. + // Note: we do not wrap the request body earlier, because [http.NewRequestWithContext] performs + // type probing on the body variable to determine the content length. + if request.Body != nil && request.Body != http.NoBody { + request.Body = newRequestBodyTicker(ticker, request.Body) + } // attempt the actual request response, err := c.httpClient.Do(request) // After this point, the request body has (probably) been consumed. handleError() must be called to reset it if // possible. 
- if _, ok := err.(*url.Error); ok { + if uerr, ok := err.(*url.Error); ok { + // If the timeout context has been canceled but the parent context hasn't, then the request has timed out. NOTE(review): this branch returns without calling ticker.Cancel(), so after a non-timeout error the watcher goroutine appears to linger until the deadline passes or the parent context ends; confirm this is acceptable. + if pctx.Err() == nil && uerr.Err == context.Canceled { + uerr.Err = fmt.Errorf("request timed out after %s of inactivity", c.config.HTTPTimeout) + } return c.handleError(ctx, err, requestBody) } + // If there is a response body, wrap it to extend the request timeout while it is being read. + if response != nil && response.Body != nil { + response.Body = newResponseBodyTicker(ticker, response.Body) + } else { + // If there is no response body, the request has completed and there + // is no need to extend the timeout. Cancel the context to clean up + // the underlying goroutine. + ticker.Cancel() + } + // By this point, the request body has certainly been consumed. responseWrapper, err := common.NewResponseWrapper(response, requestBody) if err != nil { diff --git a/httpclient/api_client_test.go b/httpclient/api_client_test.go index 725a4d9d8..9ef9eaf87 100644 --- a/httpclient/api_client_test.go +++ b/httpclient/api_client_test.go @@ -275,6 +275,55 @@ func TestSimpleRequestErrReaderCloseBody_StreamResponse(t *testing.T) { require.NoError(t, err, "response body should not be closed for streaming responses") } +func timeoutTransport(r *http.Request) (*http.Response, error) { + select { + case <-r.Context().Done(): + return nil, r.Context().Err() + case <-time.After(50 * time.Millisecond): + return nil, fmt.Errorf("test timeout") + } +} + +func TestSimpleRequestContextCancel(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + + // Cancel outer context after 10ms + go func() { + defer cancel() + time.Sleep(10 * time.Millisecond) + }() + + c := NewApiClient(ClientConfig{ + Transport: hc(timeoutTransport), + }) + err := c.Do(ctx, "GET", "/a", WithRequestData(map[string]any{})) + require.ErrorContains(t, err, "context canceled") +} + +func 
TestSimpleRequestContextDeadline(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(10*time.Millisecond)) + defer cancel() + + c := NewApiClient(ClientConfig{ + Transport: hc(timeoutTransport), + }) + err := c.Do(ctx, "GET", "/a", WithRequestData(map[string]any{})) + require.ErrorContains(t, err, "context deadline exceeded") +} + +func TestSimpleRequestTimeout(t *testing.T) { + ctx := context.Background() + + c := NewApiClient(ClientConfig{ + HTTPTimeout: 10 * time.Millisecond, + Transport: hc(timeoutTransport), + }) + err := c.Do(ctx, "GET", "/a", WithRequestData(map[string]any{})) + require.ErrorContains(t, err, "request timed out after 10ms of inactivity") +} + type BufferLogger struct { strings.Builder } diff --git a/httpclient/timeout_context.go b/httpclient/timeout_context.go new file mode 100644 index 000000000..ed20cf66b --- /dev/null +++ b/httpclient/timeout_context.go @@ -0,0 +1,117 @@ +package httpclient + +import ( + "context" + "io" + "sync" + "time" +) + +type timeoutContext struct { + ctx context.Context + cancel context.CancelFunc + + // timeout is the fixed inactivity window; it is never modified after construction. + // deadline is moved to now plus timeout each time Tick is called. + timeout time.Duration + deadline time.Time + + // Protect against concurrent deadline reads/writes. + lock sync.Mutex +} + +type TimeoutTicker interface { + Tick() + Cancel() +} + +func newTimeoutContext(ctx context.Context, timeout time.Duration) (context.Context, TimeoutTicker) { + ctx, cancel := context.WithCancel(ctx) + t := &timeoutContext{ + ctx: ctx, + cancel: cancel, + timeout: timeout, + deadline: time.Now().Add(timeout), + } + + // Start a goroutine that cancels the context once the deadline passes without being extended. + go t.run() + return ctx, t +} + +// Tick updates the deadline to the current time plus the timeout. +func (t *timeoutContext) Tick() { + t.lock.Lock() + defer t.lock.Unlock() + t.deadline = time.Now().Add(t.timeout) +} + +// Cancel cancels the context, which also makes the run goroutine return. 
+func (t *timeoutContext) Cancel() { + t.cancel() +} + +// Deadline returns the current deadline. +func (t *timeoutContext) Deadline() time.Time { + t.lock.Lock() + defer t.lock.Unlock() + return t.deadline +} + +func (t *timeoutContext) run() { + for { + ttl := time.Until(t.Deadline()) + if ttl <= 0 { + t.cancel() + return + } + + timer := time.NewTimer(ttl) + select { + case <-timer.C: + // The timer fired: loop around to re-read the deadline, which Tick may have extended + continue + case <-t.ctx.Done(): + timer.Stop() + return + } + } +} + +// tickingReadCloser wraps an io.ReadCloser and calls Tick after every Read returns, whether or not the read succeeded. +type tickingReadCloser struct { + rc io.ReadCloser + t TimeoutTicker +} + +func (t tickingReadCloser) Read(p []byte) (n int, err error) { + defer t.t.Tick() + return t.rc.Read(p) +} + +func (t tickingReadCloser) Close() error { + return t.rc.Close() +} + +// cancellingReadCloser wraps an io.ReadCloser and calls Cancel after the wrapped Close returns. +type cancellingReadCloser struct { + rc io.ReadCloser + t TimeoutTicker +} + +func (t cancellingReadCloser) Read(p []byte) (n int, err error) { + return t.rc.Read(p) +} + +func (t cancellingReadCloser) Close() error { + defer t.t.Cancel() + return t.rc.Close() +} + +func newRequestBodyTicker(t TimeoutTicker, r io.ReadCloser) io.ReadCloser { + return tickingReadCloser{r, t} +} + +func newResponseBodyTicker(t TimeoutTicker, r io.ReadCloser) io.ReadCloser { + return cancellingReadCloser{tickingReadCloser{r, t}, t} +} diff --git a/httpclient/timeout_context_test.go b/httpclient/timeout_context_test.go new file mode 100644 index 000000000..81d161ff8 --- /dev/null +++ b/httpclient/timeout_context_test.go @@ -0,0 +1,43 @@ +package httpclient + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTimeoutContextTimeout(t *testing.T) { + ctx := context.Background() + ctx, _ = newTimeoutContext(ctx, time.Millisecond*50) + time.Sleep(time.Millisecond * 100) + + // The context should have timed out; a timeout is reported as context.Canceled. 
+ assert.Equal(t, context.Canceled, ctx.Err()) +} + +func TestTimeoutContextTick(t *testing.T) { + ctx := context.Background() + ctx, ticker := newTimeoutContext(ctx, time.Millisecond*50) + + // Extend the deadline a couple of times. + for i := 0; i < 5; i++ { + ticker.Tick() + time.Sleep(time.Millisecond * 25) + } + + // The context should not have timed out. + assert.Nil(t, ctx.Err()) +} + +func TestTimeoutContextCancel(t *testing.T) { + ctx := context.Background() + ctx, ticker := newTimeoutContext(ctx, time.Millisecond*50) + + // Cancel the context. + ticker.Cancel() + + // The context should have been canceled by the explicit Cancel call, not by the timeout. + assert.Equal(t, context.Canceled, ctx.Err()) +}