diff --git a/pkg/frontend/querymiddleware/shard_active_series.go b/pkg/frontend/querymiddleware/shard_active_series.go index 9af7b89ee3b..3dc2bf72005 100644 --- a/pkg/frontend/querymiddleware/shard_active_series.go +++ b/pkg/frontend/querymiddleware/shard_active_series.go @@ -52,6 +52,18 @@ var ( } ) +var snappyWriterPool sync.Pool + +func getSnappyWriter(w io.Writer) *s2.Writer { + sw := snappyWriterPool.Get() + if sw == nil { + return s2.NewWriter(w) + } + enc := sw.(*s2.Writer) + enc.Reset(w) + return enc +} + type shardActiveSeriesMiddleware struct { upstream http.RoundTripper limits Limits @@ -257,7 +269,7 @@ func shardedSelector(shardCount, currentShard int, expr parser.Expr) (parser.Exp }, nil } -func (s *shardActiveSeriesMiddleware) mergeResponses(ctx context.Context, responses []*http.Response, acceptEncoding string) *http.Response { +func (s *shardActiveSeriesMiddleware) mergeResponses(ctx context.Context, responses []*http.Response, encoding string) *http.Response { reader, writer := io.Pipe() items := make(chan *labels.Builder, len(responses)) @@ -326,29 +338,34 @@ func (s *shardActiveSeriesMiddleware) mergeResponses(ctx context.Context, respon close(items) }() - response := &http.Response{Body: reader, StatusCode: http.StatusOK, Header: http.Header{}} - response.Header.Set("Content-Type", "application/json") - if acceptEncoding == encodingTypeSnappyFramed { - response.Header.Set("Content-Encoding", encodingTypeSnappyFramed) + resp := &http.Response{Body: reader, StatusCode: http.StatusOK, Header: http.Header{}} + resp.Header.Set("Content-Type", "application/json") + if encoding == encodingTypeSnappyFramed { + resp.Header.Set("Content-Encoding", encodingTypeSnappyFramed) } - go s.writeMergedResponse(ctx, g.Wait, writer, items, acceptEncoding) + go s.writeMergedResponse(ctx, g.Wait, writer, items, encoding) - return response + return resp } -func (s *shardActiveSeriesMiddleware) writeMergedResponse(ctx context.Context, check func() error, w io.WriteCloser, items <-chan *labels.Builder, encodingType string) { +func (s *shardActiveSeriesMiddleware) writeMergedResponse(ctx context.Context, check func() error, w io.WriteCloser, items <-chan *labels.Builder, encoding string) { defer w.Close() span, _ := opentracing.StartSpanFromContext(ctx, "shardActiveSeries.writeMergedResponse") defer span.Finish() var out io.Writer = w - if encodingType == encodingTypeSnappyFramed { + if encoding == encodingTypeSnappyFramed { span.LogFields(otlog.String("encoding", encodingTypeSnappyFramed)) - enc := s2.NewWriter(w) - defer enc.Close() + enc := getSnappyWriter(w) out = enc + defer func() { + enc.Close() + // Reset the encoder before putting it back to pool to avoid it to hold the writer. + enc.Reset(nil) + snappyWriterPool.Put(enc) + }() } else { span.LogFields(otlog.String("encoding", "none")) } diff --git a/pkg/frontend/querymiddleware/shard_active_series_test.go b/pkg/frontend/querymiddleware/shard_active_series_test.go index 71b8babe695..fbb29d9e586 100644 --- a/pkg/frontend/querymiddleware/shard_active_series_test.go +++ b/pkg/frontend/querymiddleware/shard_active_series_test.go @@ -12,6 +12,7 @@ import ( "net/url" "strconv" "strings" + "sync" "testing" "github.com/go-kit/log" @@ -270,12 +271,6 @@ func Test_shardActiveSeriesMiddleware_RoundTrip(t *testing.T) { // Stub upstream with valid or invalid responses. var requestCount atomic.Int32 upstream := RoundTripFunc(func(r *http.Request) (*http.Response, error) { - defer func(body io.ReadCloser) { - if body != nil { - _ = body.Close() - } - }(r.Body) - _, _, err := user.ExtractOrgIDFromHTTPRequest(r) require.NoError(t, err) _, err = user.ExtractOrgID(r.Context()) @@ -358,7 +353,85 @@ func Test_shardActiveSeriesMiddleware_RoundTrip(t *testing.T) { } } +func Test_shardActiveSeriesMiddleware_RoundTrip_concurrent(t *testing.T) { + const shardCount = 4 + + upstream := RoundTripFunc(func(r *http.Request) (*http.Response, error) { + require.NoError(t, r.ParseForm()) + req, err := cardinality.DecodeActiveSeriesRequestFromValues(r.Form) + require.NoError(t, err) + shard, _, err := sharding.ShardFromMatchers(req.Matchers) + require.NoError(t, err) + require.NotNil(t, shard) + + resp := fmt.Sprintf(`{"data": [{"__name__": "metric-%d"}]}`, shard.ShardIndex) + + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(resp))}, nil + }) + + s := newShardActiveSeriesMiddleware( + upstream, + mockLimits{maxShardedQueries: shardCount, totalShards: shardCount}, + log.NewNopLogger(), + ) + + assertRoundTrip := func(t *testing.T, trip http.RoundTripper, req *http.Request) { + resp, err := trip.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var body io.Reader = resp.Body + if resp.Header.Get("Content-Encoding") == encodingTypeSnappyFramed { + body = s2.NewReader(resp.Body) + } + + // For this test, if we can decode the response, it is enough to guaranty it worked. We proof actual validity + // of all kinds of responses in the tests above. + var res result + err = json.NewDecoder(body).Decode(&res) + require.NoError(t, err) + require.Len(t, res.Data, shardCount) + } + + const reqCount = 20 + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(reqCount) + + for n := reqCount; n > 0; n-- { + go func(n int) { + defer wg.Done() + + req := httptest.NewRequest("POST", "/active_series", strings.NewReader(`selector={__name__=~"metric-.*"}`)) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + // Send every other request as snappy to proof the middleware doesn't mess up body encoders + if n%2 == 0 { + req.Header.Add("Accept-Encoding", encodingTypeSnappyFramed) + } + + req = req.WithContext(user.InjectOrgID(req.Context(), "test")) + + assertRoundTrip(t, s, req) + }(n) + } +} + func BenchmarkActiveSeriesMiddlewareMergeResponses(b *testing.B) { + b.Run("encoding=none", func(b *testing.B) { + benchmarkActiveSeriesMiddlewareMergeResponses(b, "") + }) + + b.Run("encoding=snappy", func(b *testing.B) { + benchmarkActiveSeriesMiddlewareMergeResponses(b, encodingTypeSnappyFramed) + }) +} + +func benchmarkActiveSeriesMiddlewareMergeResponses(b *testing.B, encoding string) { type activeSeriesResponse struct { Data []labels.Labels `json:"data"` } @@ -392,7 +465,7 @@ func BenchmarkActiveSeriesMiddlewareMergeResponses(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - resp := s.mergeResponses(context.Background(), benchResponses[i], "") + resp := s.mergeResponses(context.Background(), benchResponses[i], encoding) _, _ = io.Copy(io.Discard, resp.Body) _ = resp.Body.Close()