diff --git a/br/pkg/pdutil/main_test.go b/br/pkg/pdutil/main_test.go new file mode 100644 index 0000000000000..861c3921a3eb3 --- /dev/null +++ b/br/pkg/pdutil/main_test.go @@ -0,0 +1,31 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pdutil + +import ( + "testing" + + "github.com/pingcap/tidb/util/testbridge" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + testbridge.WorkaroundGoCheckFlags() + opts := []goleak.Option{ + goleak.IgnoreTopFunction("go.etcd.io/etcd/pkg/logutil.(*MergeLogger).outputLoop"), + goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), + } + goleak.VerifyTestMain(m, opts...) +} diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index 5610578c0f766..2f898d9c062ef 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -156,14 +156,17 @@ func pdRequest( if count > pdRequestRetryTime || resp.StatusCode < 500 { break } - resp.Body.Close() - time.Sleep(time.Second) + _ = resp.Body.Close() + time.Sleep(pdRequestRetryInterval()) resp, err = cli.Do(req) if err != nil { return nil, errors.Trace(err) } } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() + if resp.StatusCode != http.StatusOK { res, _ := io.ReadAll(resp.Body) return nil, errors.Annotatef(berrors.ErrPDInvalidResponse, "[%d] %s %s", resp.StatusCode, res, reqURL) @@ -176,6 +179,15 @@ func pdRequest( return r, nil } +func pdRequestRetryInterval() time.Duration { + failpoint.Inject("FastRetry", func(v failpoint.Value) { + if v.(bool) { + failpoint.Return(0) + } + }) + return time.Second +} + // PdController manage get/update config from pd. type PdController struct { addrs []string diff --git a/br/pkg/pdutil/pd_test.go b/br/pkg/pdutil/pd_serial_test.go similarity index 72% rename from br/pkg/pdutil/pd_test.go rename to br/pkg/pdutil/pd_serial_test.go index e4e82d412171c..2dde535cd54b9 100644 --- a/br/pkg/pdutil/pd_test.go +++ b/br/pkg/pdutil/pd_serial_test.go @@ -15,26 +15,19 @@ import ( "testing" "github.com/coreos/go-semver/semver" - . "github.com/pingcap/check" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/tidb/util/codec" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/statistics" ) -func TestT(t *testing.T) { - TestingT(t) -} - -type testPDControllerSuite struct { -} - -var _ = Suite(&testPDControllerSuite{}) - -func (s *testPDControllerSuite) TestScheduler(c *C) { - ctx := context.Background() +func TestScheduler(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() scheduler := "balance-leader-scheduler" mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { @@ -44,13 +37,13 @@ func (s *testPDControllerSuite) TestScheduler(c *C) { pdController := &PdController{addrs: []string{"", ""}, schedulerPauseCh: schedulerPauseCh} _, err := pdController.pauseSchedulersAndConfigWith(ctx, []string{scheduler}, nil, mock) - c.Assert(err, ErrorMatches, "failed") + require.EqualError(t, err, "failed") go func() { <-schedulerPauseCh }() err = pdController.resumeSchedulerWith(ctx, []string{scheduler}, mock) - c.Assert(err, IsNil) + require.NoError(t, err) cfg := map[string]interface{}{ "max-merge-region-keys": 0, @@ -59,34 +52,37 @@ func (s *testPDControllerSuite) TestScheduler(c *C) { "max-pending-peer-count": uint64(16), } _, err = pdController.pauseSchedulersAndConfigWith(ctx, []string{}, cfg, mock) - c.Assert(err, ErrorMatches, "failed to update PD.*") + require.Error(t, err) + require.Regexp(t, "^failed to update PD.*", err.Error()) go func() { <-schedulerPauseCh }() + err = pdController.resumeSchedulerWith(ctx, []string{scheduler}, mock) + require.NoError(t, err) _, err = pdController.listSchedulersWith(ctx, mock) - c.Assert(err, ErrorMatches, "failed") + require.EqualError(t, err, "failed") mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { return []byte(`["` + scheduler + `"]`), nil } _, err = pdController.pauseSchedulersAndConfigWith(ctx, []string{scheduler}, cfg, mock) - c.Assert(err, IsNil) + require.NoError(t, err) go func() { <-schedulerPauseCh }() err = pdController.resumeSchedulerWith(ctx, []string{scheduler}, mock) - c.Assert(err, IsNil) + require.NoError(t, err) schedulers, err := pdController.listSchedulersWith(ctx, mock) - c.Assert(err, IsNil) - c.Assert(schedulers, HasLen, 1) - c.Assert(schedulers[0], Equals, scheduler) + require.NoError(t, err) + require.Len(t, schedulers, 1) + require.Equal(t, scheduler, schedulers[0]) } -func (s *testPDControllerSuite) TestGetClusterVersion(c *C) { +func TestGetClusterVersion(t *testing.T) { pdController := &PdController{addrs: []string{"", ""}} // two endpoints counter := 0 mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { @@ -99,17 +95,17 @@ func (s *testPDControllerSuite) TestGetClusterVersion(c *C) { ctx := context.Background() respString, err := pdController.getClusterVersionWith(ctx, mock) - c.Assert(err, IsNil) - c.Assert(respString, Equals, "test") + require.NoError(t, err) + require.Equal(t, "test", respString) mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { return nil, errors.New("mock error") } _, err = pdController.getClusterVersionWith(ctx, mock) - c.Assert(err, NotNil) + require.Error(t, err) } -func (s *testPDControllerSuite) TestRegionCount(c *C) { +func TestRegionCount(t *testing.T) { regions := core.NewRegionsInfo() regions.SetRegion(core.NewRegionInfo(&metapb.Region{ Id: 1, @@ -129,55 +125,61 @@ func (s *testPDControllerSuite) TestRegionCount(c *C) { EndKey: codec.EncodeBytes(nil, []byte{3, 4}), RegionEpoch: &metapb.RegionEpoch{}, }, nil)) - c.Assert(regions.Len(), Equals, 3) + require.Equal(t, 3, regions.Len()) mock := func( _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) u, e := url.Parse(query) - c.Assert(e, IsNil, Commentf("%s", query)) + require.NoError(t, e, query) start := u.Query().Get("start_key") end := u.Query().Get("end_key") - c.Log(hex.EncodeToString([]byte(start))) - c.Log(hex.EncodeToString([]byte(end))) + t.Log(hex.EncodeToString([]byte(start))) + t.Log(hex.EncodeToString([]byte(end))) scanRegions := regions.ScanRange([]byte(start), []byte(end), 0) stats := statistics.RegionStats{Count: len(scanRegions)} ret, err := json.Marshal(stats) - c.Assert(err, IsNil) + require.NoError(t, err) return ret, nil } pdController := &PdController{addrs: []string{"http://mock"}} ctx := context.Background() resp, err := pdController.getRegionCountWith(ctx, mock, []byte{}, []byte{}) - c.Assert(err, IsNil) - c.Assert(resp, Equals, 3) + require.NoError(t, err) + require.Equal(t, 3, resp) resp, err = pdController.getRegionCountWith(ctx, mock, []byte{0}, []byte{0xff}) - c.Assert(err, IsNil) - c.Assert(resp, Equals, 3) + require.NoError(t, err) + require.Equal(t, 3, resp) resp, err = pdController.getRegionCountWith(ctx, mock, []byte{1, 2}, []byte{1, 4}) - c.Assert(err, IsNil) - c.Assert(resp, Equals, 2) + require.NoError(t, err) + require.Equal(t, 2, resp) } -func (s *testPDControllerSuite) TestPDVersion(c *C) { +func TestPDVersion(t *testing.T) { v := []byte("\"v4.1.0-alpha1\"\n") r := parseVersion(v) expectV := semver.New("4.1.0-alpha1") - c.Assert(r.Major, Equals, expectV.Major) - c.Assert(r.Minor, Equals, expectV.Minor) - c.Assert(r.PreRelease, Equals, expectV.PreRelease) + require.Equal(t, expectV.Major, r.Major) + require.Equal(t, expectV.Minor, r.Minor) + require.Equal(t, expectV.PreRelease, r.PreRelease) } -func (s *testPDControllerSuite) TestPDRequestRetry(c *C) { +func TestPDRequestRetry(t *testing.T) { ctx := context.Background() + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/br/pkg/pdutil/FastRetry", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/pdutil/FastRetry")) + }() + count := 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { count++ - if count <= 5 { + if count <= pdRequestRetryTime-1 { w.WriteHeader(http.StatusGatewayTimeout) return } @@ -186,12 +188,12 @@ func (s *testPDControllerSuite) TestPDRequestRetry(c *C) { cli := http.DefaultClient taddr := ts.URL _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodGet, nil) - c.Assert(reqErr, IsNil) + require.NoError(t, reqErr) ts.Close() count = 0 ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { count++ - if count <= 11 { + if count <= pdRequestRetryTime+1 { w.WriteHeader(http.StatusGatewayTimeout) return } @@ -200,10 +202,10 @@ func (s *testPDControllerSuite) TestPDRequestRetry(c *C) { defer ts.Close() taddr = ts.URL _, reqErr = pdRequest(ctx, taddr, "", cli, http.MethodGet, nil) - c.Assert(reqErr, NotNil) + require.Error(t, reqErr) } -func (s *testPDControllerSuite) TestStoreInfo(c *C) { +func TestStoreInfo(t *testing.T) { storeInfo := api.StoreInfo{ Status: &api.StoreStatus{ Capacity: typeutil.ByteSize(1024), @@ -217,18 +219,18 @@ func (s *testPDControllerSuite) TestStoreInfo(c *C) { _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) - c.Assert(query, Equals, "http://mock/pd/api/v1/store/1") + require.Equal(t, "http://mock/pd/api/v1/store/1", query) ret, err := json.Marshal(storeInfo) - c.Assert(err, IsNil) + require.NoError(t, err) return ret, nil } pdController := &PdController{addrs: []string{"http://mock"}} ctx := context.Background() resp, err := pdController.getStoreInfoWith(ctx, mock, 1) - c.Assert(err, IsNil) - c.Assert(resp, NotNil) - c.Assert(resp.Status, NotNil) - c.Assert(resp.Store.StateName, Equals, "Tombstone") - c.Assert(uint64(resp.Status.Available), Equals, uint64(1024)) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Status) + require.Equal(t, "Tombstone", resp.Store.StateName) + require.Equal(t, uint64(1024), uint64(resp.Status.Available)) }