From b5dc2a50590a421d0878e33564195855e2f666d5 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 6 Nov 2021 23:43:05 +0800 Subject: [PATCH 01/55] server: migrate `TestStatusAPIWithTLS` --- server/server_test.go | 42 ++++++++++++++++++++++ server/tidb_test.go | 84 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 119 insertions(+), 7 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 82e4dd6ab883e..555e4b01d9f77 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -39,7 +39,9 @@ import ( "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/kv" tmysql "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util/versioninfo" + "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -71,6 +73,19 @@ func newTestServerClient() *testServerClient { } } +type testingServerClient struct { + testServerClient +} + +// newTestServerClient return a testingServerClient with unique address +func newTestingServerClient() *testingServerClient { + return &testingServerClient{testServerClient{ + port: 0, + statusPort: 0, + statusScheme: "http", + }} +} + // statusURL return the full URL of a status path func (cli *testServerClient) statusURL(path string) string { return fmt.Sprintf("%s://localhost:%d%s", cli.statusScheme, cli.statusPort, path) @@ -122,6 +137,21 @@ func (cli *testServerClient) runTests(c *C, overrider configOverrider, tests ... } } +// runTests runs tests using the default database `test`. +func (cli *testingServerClient) runTests(t *testing.T, overrider configOverrider, tests ...func(dbt *testkit.DBTestKit)) { + db, err := sql.Open("mysql", cli.getDSN(overrider)) + require.NoErrorf(t, err, "Error connecting") + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + dbt := testkit.NewDBTestKit(t, db) + for _, test := range tests { + test(dbt) + } +} + // runTestsOnNewDB runs tests using a specified database which will be created before the test and destroyed after the test. func (cli *testServerClient) runTestsOnNewDB(c *C, overrider configOverrider, dbName string, tests ...func(dbt *DBTest)) { dsn := cli.getDSN(overrider, func(config *mysql.Config) { @@ -1700,6 +1730,18 @@ func (cli *testServerClient) runTestStatusAPI(c *C) { c.Assert(data.GitHash, Equals, versioninfo.TiDBGitHash) } +func (cli *testingServerClient) runTestStatusAPI(t *testing.T) { + resp, err := cli.fetchStatus("/status") + require.NoError(t, err) + defer resp.Body.Close() + decoder := json.NewDecoder(resp.Body) + var data status + err = decoder.Decode(&data) + require.NoError(t, err) + require.Equal(t, tmysql.ServerVersion, data.Version) + require.Equal(t, versioninfo.TiDBGitHash, data.GitHash) +} + // The golang sql driver (and most drivers) should have multi-statement // disabled by default for security reasons. Lets ensure that the behavior // is correct. diff --git a/server/tidb_test.go b/server/tidb_test.go index b9f9ddbf9f9e3..74a91dc3df550 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -33,6 +33,7 @@ import ( "path/filepath" "strings" "sync/atomic" + "testing" "time" "github.com/go-sql-driver/mysql" @@ -57,12 +58,78 @@ import ( mockTopSQLReporter "github.com/pingcap/tidb/util/topsql/reporter/mock" "github.com/pingcap/tidb/util/topsql/tracecpu" mockTopSQLTraceCPU "github.com/pingcap/tidb/util/topsql/tracecpu/mock" + "github.com/stretchr/testify/require" ) +type tidbTestBase struct { + *testingServerClient + tidbdrv *TiDBDriver + server *Server + domain *domain.Domain + store kv.Storage +} + +func createTiDBTestBase(t *testing.T) (*tidbTestBase, func()) { + ts := &tidbTestBase{testingServerClient: newTestingServerClient()} + + // setup tidbTestBase + var err error + ts.store, err = mockstore.NewMockStore() + session.DisableStats4Test() + require.NoError(t, err) + ts.domain, err = session.BootstrapSession(ts.store) + require.NoError(t, err) + ts.tidbdrv = NewTiDBDriver(ts.store) + cfg := newTestConfig() + cfg.Socket = "" + cfg.Port = ts.port + cfg.Status.ReportStatus = true + cfg.Status.StatusPort = ts.statusPort + cfg.Performance.TCPKeepAlive = true + err = logutil.InitLogger(cfg.Log.ToLogConfig()) + require.NoError(t, err) + + server, err := NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + ts.port = getPortFromTCPAddr(server.listener.Addr()) + ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) + ts.server = server + go func() { + err := ts.server.Run() + require.NoError(t, err) + }() + ts.waitUntilServerOnline() + + cleanup := func() { + if ts.store != nil { + ts.store.Close() + } + if ts.domain != nil { + ts.domain.Close() + } + if ts.server != nil { + ts.server.Close() + } + } + + return ts, cleanup +} + type tidbTestSuite struct { *tidbTestSuiteBase } +type tidbTest struct { + *tidbTestBase +} + +func createTiDBTest(t *testing.T) (*tidbTest, func()) { + base, cleanup := createTiDBTestBase(t) + // TODO: register metrics + // metrics.RegisterMetrics() + return &tidbTest{base}, cleanup +} + type tidbTestSerialSuite struct { *tidbTestSuiteBase } @@ -287,11 +354,14 @@ func (ts *tidbTestSuite) TestStatusPort(c *C) { c.Assert(server, IsNil) } -func (ts *tidbTestSuite) TestStatusAPIWithTLS(c *C) { +func TestStatusAPIWithTLS(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + caCert, caKey, err := generateCert(0, "TiDB CA 2", nil, nil, "/tmp/ca-key-2.pem", "/tmp/ca-cert-2.pem") - c.Assert(err, IsNil) + require.NoError(t, err) _, _, err = generateCert(1, "tidb-server-2", caCert, caKey, "/tmp/server-key-2.pem", "/tmp/server-cert-2.pem") - c.Assert(err, IsNil) + require.NoError(t, err) defer func() { os.Remove("/tmp/ca-key-2.pem") @@ -310,22 +380,22 @@ func (ts *tidbTestSuite) TestStatusAPIWithTLS(c *C) { cfg.Security.ClusterSSLCert = "/tmp/server-cert-2.pem" cfg.Security.ClusterSSLKey = "/tmp/server-key-2.pem" server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) + require.NoError(t, err) cli.port = getPortFromTCPAddr(server.listener.Addr()) cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { err := server.Run() - c.Assert(err, IsNil) + require.NoError(t, err) }() time.Sleep(time.Millisecond * 100) // https connection should work. - ts.runTestStatusAPI(c) + ts.runTestStatusAPI(t) // but plain http connection should fail. cli.statusScheme = "http" _, err = cli.fetchStatus("/status") // nolint: bodyclose - c.Assert(err, NotNil) + require.Error(t, err) server.Close() } From 34181880baccdaac44a19c18ad71b3e04713ab92 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 6 Nov 2021 23:50:32 +0800 Subject: [PATCH 02/55] server: migrate `TestStatusPort` --- server/tidb_test.go | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 74a91dc3df550..cc6e389c29b0a 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -333,15 +333,10 @@ func (ts *tidbTestSuite) TestStatusAPI(c *C) { ts.runTestStatusAPI(c) } -func (ts *tidbTestSuite) TestStatusPort(c *C) { - store, err := mockstore.NewMockStore() - c.Assert(err, IsNil) - defer store.Close() - session.DisableStats4Test() - dom, err := session.BootstrapSession(store) - c.Assert(err, IsNil) - defer dom.Close() - ts.tidbdrv = NewTiDBDriver(store) +func TestStatusPort(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + cfg := newTestConfig() cfg.Socket = "" cfg.Port = 0 @@ -350,8 +345,8 @@ func (ts *tidbTestSuite) TestStatusPort(c *C) { cfg.Performance.TCPKeepAlive = true server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, NotNil) - c.Assert(server, IsNil) + require.Error(t, err) + require.Nil(t, server) } func TestStatusAPIWithTLS(t *testing.T) { From ea496ec5adbc0800617a598374d591a2f4b655f4 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 6 Nov 2021 23:57:28 +0800 Subject: [PATCH 03/55] server: migrate `TestStatusAPIWithTLSCNCheck` --- server/tidb_test.go | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index cc6e389c29b0a..2f6678a07d326 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -334,6 +334,7 @@ func (ts *tidbTestSuite) TestStatusAPI(c *C) { } func TestStatusPort(t *testing.T) { + t.Parallel() ts, cleanup := createTiDBTest(t) defer cleanup() @@ -350,6 +351,7 @@ func TestStatusPort(t *testing.T) { } func TestStatusAPIWithTLS(t *testing.T) { + t.Parallel() ts, cleanup := createTiDBTest(t) defer cleanup() @@ -395,7 +397,11 @@ func TestStatusAPIWithTLS(t *testing.T) { server.Close() } -func (ts *tidbTestSuite) TestStatusAPIWithTLSCNCheck(c *C) { +func TestStatusAPIWithTLSCNCheck(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + caPath := filepath.Join(os.TempDir(), "ca-cert-cn.pem") serverKeyPath := filepath.Join(os.TempDir(), "server-key-cn.pem") serverCertPath := filepath.Join(os.TempDir(), "server-cert-cn.pem") @@ -405,17 +411,17 @@ func (ts *tidbTestSuite) TestStatusAPIWithTLSCNCheck(c *C) { client2CertPath := filepath.Join(os.TempDir(), "client-cert-cn-check-b.pem") caCert, caKey, err := generateCert(0, "TiDB CA CN CHECK", nil, nil, filepath.Join(os.TempDir(), "ca-key-cn.pem"), caPath) - c.Assert(err, IsNil) + require.NoError(t, err) _, _, err = generateCert(1, "tidb-server-cn-check", caCert, caKey, serverKeyPath, serverCertPath) - c.Assert(err, IsNil) + require.NoError(t, err) _, _, err = generateCert(2, "tidb-client-cn-check-a", caCert, caKey, client1KeyPath, client1CertPath, func(c *x509.Certificate) { c.Subject.CommonName = "tidb-client-1" }) - c.Assert(err, IsNil) + require.NoError(t, err) _, _, err = generateCert(3, "tidb-client-cn-check-b", caCert, caKey, client2KeyPath, client2CertPath, func(c *x509.Certificate) { c.Subject.CommonName = "tidb-client-2" }) - c.Assert(err, IsNil) + require.NoError(t, err) cli := newTestServerClient() cli.statusScheme = "https" @@ -428,37 +434,38 @@ func (ts *tidbTestSuite) TestStatusAPIWithTLSCNCheck(c *C) { cfg.Security.ClusterSSLKey = serverKeyPath cfg.Security.ClusterVerifyCN = []string{"tidb-client-2"} server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) + require.NoError(t, err) + cli.port = getPortFromTCPAddr(server.listener.Addr()) cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { err := server.Run() - c.Assert(err, IsNil) + require.NoError(t, err) }() defer server.Close() time.Sleep(time.Millisecond * 100) - hc := newTLSHttpClient(c, caPath, + hc := newTLSHttpClient(t, caPath, client1CertPath, client1KeyPath, ) _, err = hc.Get(cli.statusURL("/status")) // nolint: bodyclose - c.Assert(err, NotNil) + require.Error(t, err) - hc = newTLSHttpClient(c, caPath, + hc = newTLSHttpClient(t, caPath, client2CertPath, client2KeyPath, ) resp, err := hc.Get(cli.statusURL("/status")) - c.Assert(err, IsNil) - c.Assert(resp.Body.Close(), IsNil) + require.NoError(t, err) + require.Nil(t, resp.Body.Close()) } -func newTLSHttpClient(c *C, caFile, certFile, keyFile string) *http.Client { +func newTLSHttpClient(t *testing.T, caFile, certFile, keyFile string) *http.Client { cert, err := tls.LoadX509KeyPair(certFile, keyFile) - c.Assert(err, IsNil) + require.NoError(t, err) caCert, err := os.ReadFile(caFile) - c.Assert(err, IsNil) + require.NoError(t, err) caCertPool := x509.NewCertPool() caCertPool.AppendCertsFromPEM(caCert) tlsConfig := &tls.Config{ From 3bc068b351e2181c7f8292244e2e7915c3ea66fb Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 11:50:07 +0800 Subject: [PATCH 04/55] server: migrate `TestPessimisticInsertSelectForUpdate` --- server/tidb_test.go | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 2f6678a07d326..a065e2fa625e5 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1641,26 +1641,30 @@ func (ts *tidbTestSerialSuite) TestDefaultCharacterAndCollation(c *C) { } } -func (ts *tidbTestSuite) TestPessimisticInsertSelectForUpdate(c *C) { +func TestPessimisticInsertSelectForUpdate(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) - c.Assert(err, IsNil) + require.NoError(t, err) defer qctx.Close() ctx := context.Background() _, err = Execute(ctx, qctx, "use test;") - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(ctx, qctx, "drop table if exists t1, t2") - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(ctx, qctx, "create table t1 (id int)") - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(ctx, qctx, "create table t2 (id int)") - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(ctx, qctx, "insert into t1 select 1") - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(ctx, qctx, "begin pessimistic") - c.Assert(err, IsNil) + require.NoError(t, err) rs, err := Execute(ctx, qctx, "INSERT INTO t2 (id) select id from t1 where id = 1 for update") - c.Assert(err, IsNil) - c.Assert(rs, IsNil) // should be no delay + require.NoError(t, err) + require.Nil(t, rs) // should be no delay } func (ts *tidbTestSerialSuite) TestPrepareCount(c *C) { From 25c212fe5dc0d1f8d630d669c652d6c861f6e812 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 11:55:20 +0800 Subject: [PATCH 05/55] server: migrate `TestShowTablesFlen` --- server/tidb_test.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index a065e2fa625e5..127514bfad0cd 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1367,25 +1367,29 @@ func Execute(ctx context.Context, qc *TiDBContext, sql string) (ResultSet, error return qc.ExecuteStmt(ctx, stmts[0]) } -func (ts *tidbTestSuite) TestShowTablesFlen(c *C) { +func TestShowTablesFlen(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) - c.Assert(err, IsNil) + require.NoError(t, err) ctx := context.Background() _, err = Execute(ctx, qctx, "use test;") - c.Assert(err, IsNil) + require.NoError(t, err) testSQL := "create table abcdefghijklmnopqrstuvwxyz (i int)" _, err = Execute(ctx, qctx, testSQL) - c.Assert(err, IsNil) + require.NoError(t, err) rs, err := Execute(ctx, qctx, "show tables") - c.Assert(err, IsNil) + require.NoError(t, err) req := rs.NewChunk(nil) err = rs.Next(ctx, req) - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(err, IsNil) - c.Assert(len(cols), Equals, 1) - c.Assert(int(cols[0].ColumnLength), Equals, 26*tmysql.MaxBytesOfCharacter) + require.NoError(t, err) + require.Len(t, cols, 1) + require.Equal(t, 26*tmysql.MaxBytesOfCharacter, int(cols[0].ColumnLength)) } func checkColNames(c *C, columns []*ColumnInfo, names ...string) { From abb324f58fa7da8f5cd9056c50cc9956af6942fd Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 12:09:04 +0800 Subject: [PATCH 06/55] server: migate `TestFieldList` --- server/tidb_test.go | 46 ++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 127514bfad0cd..940c92644ec48 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1392,18 +1392,22 @@ func TestShowTablesFlen(t *testing.T) { require.Equal(t, 26*tmysql.MaxBytesOfCharacter, int(cols[0].ColumnLength)) } -func checkColNames(c *C, columns []*ColumnInfo, names ...string) { +func checkColNames(t *testing.T, columns []*ColumnInfo, names ...string) { for i, name := range names { - c.Assert(columns[i].Name, Equals, name) - c.Assert(columns[i].OrgName, Equals, name) + require.Equal(t, name, columns[i].Name) + require.Equal(t, name, columns[i].OrgName) } } -func (ts *tidbTestSuite) TestFieldList(c *C) { +func TestFieldList(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(context.Background(), qctx, "use test;") - c.Assert(err, IsNil) + require.NoError(t, err) ctx := context.Background() testSQL := `create table t ( @@ -1428,22 +1432,22 @@ func (ts *tidbTestSuite) TestFieldList(c *C) { c_year year )` _, err = Execute(ctx, qctx, testSQL) - c.Assert(err, IsNil) + require.NoError(t, err) colInfos, err := qctx.FieldList("t") - c.Assert(err, IsNil) - c.Assert(len(colInfos), Equals, 19) + require.NoError(t, err) + require.Len(t, colInfos, 19) - checkColNames(c, colInfos, "c_bit", "c_int_d", "c_bigint_d", "c_float_d", + checkColNames(t, colInfos, "c_bit", "c_int_d", "c_bigint_d", "c_float_d", "c_double_d", "c_decimal", "c_datetime", "c_time", "c_date", "c_timestamp", "c_char", "c_varchar", "c_text_d", "c_binary", "c_blob_d", "c_set", "c_enum", "c_json", "c_year") for _, cols := range colInfos { - c.Assert(cols.Schema, Equals, "test") + require.Equal(t, "test", cols.Schema) } for _, cols := range colInfos { - c.Assert(cols.Table, Equals, "t") + require.Equal(t, "t", cols.Table) } for i, col := range colInfos { @@ -1451,31 +1455,31 @@ func (ts *tidbTestSuite) TestFieldList(c *C) { case 10, 11, 12, 15, 16: // c_char char(20), c_varchar varchar(20), c_text_d text, // c_set set('a', 'b', 'c'), c_enum enum('a', 'b', 'c') - c.Assert(col.Charset, Equals, uint16(tmysql.CharsetNameToID(tmysql.DefaultCharset)), Commentf("index %d", i)) + require.Equalf(t, uint16(tmysql.CharsetNameToID(tmysql.DefaultCharset)), col.Charset, "index %d", i) continue } - c.Assert(col.Charset, Equals, uint16(tmysql.CharsetNameToID("binary")), Commentf("index %d", i)) + require.Equalf(t, uint16(tmysql.CharsetNameToID("binary")), col.Charset, "index %d", i) } // c_decimal decimal(6, 3) - c.Assert(colInfos[5].Decimal, Equals, uint8(3)) + require.Equal(t, uint8(3), colInfos[5].Decimal) // for issue#10513 tooLongColumnAsName := "COALESCE(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0)" columnAsName := tooLongColumnAsName[:tmysql.MaxAliasIdentifierLen] rs, err := Execute(ctx, qctx, "select "+tooLongColumnAsName) - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(cols[0].OrgName, Equals, tooLongColumnAsName) - c.Assert(cols[0].Name, Equals, columnAsName) + require.Equal(t, tooLongColumnAsName, cols[0].OrgName) + require.Equal(t, columnAsName, cols[0].Name) rs, err = Execute(ctx, qctx, "select c_bit as '"+tooLongColumnAsName+"' from t") - c.Assert(err, IsNil) + require.NoError(t, err) cols = rs.Columns() - c.Assert(cols[0].OrgName, Equals, "c_bit") - c.Assert(cols[0].Name, Equals, columnAsName) + require.Equal(t, "c_bit", cols[0].OrgName) + require.Equal(t, columnAsName, cols[0].Name) } func (ts *tidbTestSuite) TestClientErrors(c *C) { From db99c53ab9aff9f300305a80d921138bfd030649 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 12:20:44 +0800 Subject: [PATCH 07/55] server: migrate `TestNullFlag` --- server/tidb_test.go | 48 ++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 940c92644ec48..9e70874823de2 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1495,68 +1495,72 @@ func (ts *tidbTestSuite) TestSumAvg(c *C) { ts.runTestSumAvg(c) } -func (ts *tidbTestSuite) TestNullFlag(c *C) { +func TestNullFlag(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) - c.Assert(err, IsNil) + require.NoError(t, err) ctx := context.Background() { // issue #9689 rs, err := Execute(ctx, qctx, "select 1") - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(len(cols), Equals, 1) + require.Len(t, cols, 1) expectFlag := uint16(tmysql.NotNullFlag | tmysql.BinaryFlag) - c.Assert(dumpFlag(cols[0].Type, cols[0].Flag), Equals, expectFlag) + require.Equal(t, expectFlag, dumpFlag(cols[0].Type, cols[0].Flag)) } { // issue #19025 rs, err := Execute(ctx, qctx, "select convert('{}', JSON)") - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(len(cols), Equals, 1) + require.Len(t, cols, 1) expectFlag := uint16(tmysql.BinaryFlag) - c.Assert(dumpFlag(cols[0].Type, cols[0].Flag), Equals, expectFlag) + require.Equal(t, expectFlag, dumpFlag(cols[0].Type, cols[0].Flag)) } { // issue #18488 _, err := Execute(ctx, qctx, "use test") - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(ctx, qctx, "CREATE TABLE `test` (`iD` bigint(20) NOT NULL, `INT_TEST` int(11) DEFAULT NULL);") - c.Assert(err, IsNil) + require.NoError(t, err) rs, err := Execute(ctx, qctx, `SELECT id + int_test as res FROM test GROUP BY res ORDER BY res;`) - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(len(cols), Equals, 1) + require.Len(t, cols, 1) expectFlag := uint16(tmysql.BinaryFlag) - c.Assert(dumpFlag(cols[0].Type, cols[0].Flag), Equals, expectFlag) + require.Equal(t, expectFlag, dumpFlag(cols[0].Type, cols[0].Flag)) } { rs, err := Execute(ctx, qctx, "select if(1, null, 1) ;") - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(len(cols), Equals, 1) + require.Len(t, cols, 1) expectFlag := uint16(tmysql.BinaryFlag) - c.Assert(dumpFlag(cols[0].Type, cols[0].Flag), Equals, expectFlag) + require.Equal(t, expectFlag, dumpFlag(cols[0].Type, cols[0].Flag)) } { rs, err := Execute(ctx, qctx, "select CASE 1 WHEN 2 THEN 1 END ;") - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(len(cols), Equals, 1) + require.Len(t, cols, 1) expectFlag := uint16(tmysql.BinaryFlag) - c.Assert(dumpFlag(cols[0].Type, cols[0].Flag), Equals, expectFlag) + require.Equal(t, expectFlag, dumpFlag(cols[0].Type, cols[0].Flag)) } { rs, err := Execute(ctx, qctx, "select NULL;") - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(len(cols), Equals, 1) + require.Len(t, cols, 1) expectFlag := uint16(tmysql.BinaryFlag) - c.Assert(dumpFlag(cols[0].Type, cols[0].Flag), Equals, expectFlag) + require.Equal(t, expectFlag, dumpFlag(cols[0].Type, cols[0].Flag)) } } From 651883e8a53ec7791530896c2b1348791b8a9c65 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 12:25:57 +0800 Subject: [PATCH 08/55] server: migrate `TestNO_DEFAULT_VALUEFlag` --- server/tidb_test.go | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 9e70874823de2..c40f2ee4e70be 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1564,24 +1564,28 @@ func TestNullFlag(t *testing.T) { } } -func (ts *tidbTestSuite) TestNO_DEFAULT_VALUEFlag(c *C) { +func TestNO_DEFAULT_VALUEFlag(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + // issue #21465 qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) - c.Assert(err, IsNil) + require.NoError(t, err) ctx := context.Background() _, err = Execute(ctx, qctx, "use test") - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(ctx, qctx, "drop table if exists t") - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(ctx, qctx, "create table t(c1 int key, c2 int);") - c.Assert(err, IsNil) + require.NoError(t, err) rs, err := Execute(ctx, qctx, "select c1 from t;") - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(len(cols), Equals, 1) + require.Len(t, cols, 1) expectFlag := uint16(tmysql.NotNullFlag | tmysql.PriKeyFlag | tmysql.NoDefaultValueFlag) - c.Assert(dumpFlag(cols[0].Type, cols[0].Flag), Equals, expectFlag) + require.Equal(t, expectFlag, dumpFlag(cols[0].Type, cols[0].Flag)) } func (ts *tidbTestSuite) TestGracefulShutdown(c *C) { From 1ef56e5805f5c6bd1ba103c9e011a7abb4064e75 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 13:01:15 +0800 Subject: [PATCH 09/55] server: migrate `TestCreateTableFlen` --- server/tidb_test.go | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index c40f2ee4e70be..2e59d6d51fd9d 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1297,12 +1297,16 @@ func (ts *tidbTestSuite) TestClientWithCollation(c *C) { ts.runTestClientWithCollation(c) } -func (ts *tidbTestSuite) TestCreateTableFlen(c *C) { +func TestCreateTableFlen(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + // issue #4540 qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) - c.Assert(err, IsNil) + require.NoError(t, err) _, err = Execute(context.Background(), qctx, "use test;") - c.Assert(err, IsNil) + require.NoError(t, err) ctx := context.Background() testSQL := "CREATE TABLE `t1` (" + @@ -1335,25 +1339,25 @@ func (ts *tidbTestSuite) TestCreateTableFlen(c *C) { "PRIMARY KEY (`a`)" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin" _, err = Execute(ctx, qctx, testSQL) - c.Assert(err, IsNil) + require.NoError(t, err) rs, err := Execute(ctx, qctx, "show create table t1") - c.Assert(err, IsNil) + require.NoError(t, err) req := rs.NewChunk(nil) err = rs.Next(ctx, req) - c.Assert(err, IsNil) + require.NoError(t, err) cols := rs.Columns() - c.Assert(err, IsNil) - c.Assert(len(cols), Equals, 2) - c.Assert(int(cols[0].ColumnLength), Equals, 5*tmysql.MaxBytesOfCharacter) - c.Assert(int(cols[1].ColumnLength), Equals, len(req.GetRow(0).GetString(1))*tmysql.MaxBytesOfCharacter) + require.NoError(t, err) + require.Len(t, cols, 2) + require.Equal(t, 5*tmysql.MaxBytesOfCharacter, int(cols[0].ColumnLength)) + require.Equal(t, len(req.GetRow(0).GetString(1))*tmysql.MaxBytesOfCharacter, int(cols[1].ColumnLength)) // for issue#5246 rs, err = Execute(ctx, qctx, "select y, z from t1") - c.Assert(err, IsNil) + require.NoError(t, err) cols = rs.Columns() - c.Assert(len(cols), Equals, 2) - c.Assert(int(cols[0].ColumnLength), Equals, 21) - c.Assert(int(cols[1].ColumnLength), Equals, 22) + require.Len(t, cols, 2) + require.Equal(t, 21, int(cols[0].ColumnLength)) + require.Equal(t, 22, int(cols[1].ColumnLength)) } func Execute(ctx context.Context, qc *TiDBContext, sql string) (ResultSet, error) { From a8ceb0966be06f98cf7ffb58ca14e51c5b210ee0 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 13:07:45 +0800 Subject: [PATCH 10/55] server: migrate `TestSystemTimeZone` --- server/tidb_test.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 2e59d6d51fd9d..f6996b7b5ce8b 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -49,11 +49,11 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/plancodec" - "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/topsql/reporter" mockTopSQLReporter "github.com/pingcap/tidb/util/topsql/reporter/mock" "github.com/pingcap/tidb/util/topsql/tracecpu" @@ -940,14 +940,18 @@ func registerTLSConfig(configName string, caCertPath string, clientCertPath stri return mysql.RegisterTLSConfig(configName, tlsConfig) } -func (ts *tidbTestSuite) TestSystemTimeZone(c *C) { - tk := testkit.NewTestKit(c, ts.store) +func TestSystemTimeZone(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + tk := testkit.NewTestKit(t, ts.store) cfg := newTestConfig() cfg.Socket = "" cfg.Port, cfg.Status.StatusPort = 0, 0 cfg.Status.ReportStatus = false server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) + require.NoError(t, err) defer server.Close() tz1 := tk.MustQuery("select variable_value from mysql.tidb where variable_name = 'system_tz'").Rows() From 937568e912ed8fad57e96cabba3d0c47eb7d1b22 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 14:49:53 +0800 Subject: [PATCH 11/55] server: migrate `TestAuth` --- server/server_test.go | 74 +++++++++++++++++++++---------------------- server/tidb_test.go | 11 ++++--- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 555e4b01d9f77..2aeebb5694634 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1563,57 +1563,57 @@ func checkErrorCode(c *C, e error, codes ...uint16) { c.Assert(isMatchCode, IsTrue, Commentf("got err %v, expected err codes %v", me, codes)) } -func (cli *testServerClient) runTestAuth(c *C) { - cli.runTests(c, nil, func(dbt *DBTest) { - dbt.mustExec(`CREATE USER 'authtest'@'%' IDENTIFIED BY '123';`) - dbt.mustExec(`CREATE ROLE 'authtest_r1'@'%';`) - dbt.mustExec(`GRANT ALL on test.* to 'authtest'`) - dbt.mustExec(`GRANT authtest_r1 to 'authtest'`) - dbt.mustExec(`SET DEFAULT ROLE authtest_r1 TO authtest`) +func (cli *testingServerClient) runTestAuth(t *testing.T) { + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec(`CREATE USER 'authtest'@'%' IDENTIFIED BY '123';`) + dbt.MustExec(`CREATE ROLE 'authtest_r1'@'%';`) + dbt.MustExec(`GRANT ALL on test.* to 'authtest'`) + dbt.MustExec(`GRANT authtest_r1 to 'authtest'`) + dbt.MustExec(`SET DEFAULT ROLE authtest_r1 TO authtest`) }) - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "authtest" config.Passwd = "123" - }, func(dbt *DBTest) { - dbt.mustExec(`USE information_schema;`) + }, func(dbt *testkit.DBTestKit) { + dbt.MustExec(`USE information_schema;`) }) db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { config.User = "authtest" config.Passwd = "456" })) - c.Assert(err, IsNil) + require.NoError(t, err) _, err = db.Query("USE information_schema;") - c.Assert(err, NotNil, Commentf("Wrong password should be failed")) + require.NotNilf(t, err, "Wrong password should be failed") err = db.Close() - c.Assert(err, IsNil) + require.NoError(t, err) // Test for loading active roles. db, err = sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { config.User = "authtest" config.Passwd = "123" })) - c.Assert(err, IsNil) + require.NoError(t, err) rows, err := db.Query("select current_role;") - c.Assert(err, IsNil) - c.Assert(rows.Next(), IsTrue) + require.NoError(t, err) + require.True(t, rows.Next()) var outA string err = rows.Scan(&outA) - c.Assert(err, IsNil) - c.Assert(outA, Equals, "`authtest_r1`@`%`") + require.NoError(t, err) + require.Equal(t, "`authtest_r1`@`%`", outA) err = db.Close() - c.Assert(err, IsNil) + require.NoError(t, err) // Test login use IP that not exists in mysql.user. - cli.runTests(c, nil, func(dbt *DBTest) { - dbt.mustExec(`CREATE USER 'authtest2'@'localhost' IDENTIFIED BY '123';`) - dbt.mustExec(`GRANT ALL on test.* to 'authtest2'@'localhost'`) + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec(`CREATE USER 'authtest2'@'localhost' IDENTIFIED BY '123';`) + dbt.MustExec(`GRANT ALL on test.* to 'authtest2'@'localhost'`) }) - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "authtest2" config.Passwd = "123" - }, func(dbt *DBTest) { - dbt.mustExec(`USE information_schema;`) + }, func(dbt *testkit.DBTestKit) { + dbt.MustExec(`USE information_schema;`) }) } @@ -1669,31 +1669,31 @@ func (cli *testServerClient) runTestIssue22646(c *C) { }) } -func (cli *testServerClient) runTestIssue3682(c *C) { - cli.runTests(c, nil, func(dbt *DBTest) { - dbt.mustExec(`CREATE USER 'issue3682'@'%' IDENTIFIED BY '123';`) - dbt.mustExec(`GRANT ALL on test.* to 'issue3682'`) - dbt.mustExec(`GRANT ALL on mysql.* to 'issue3682'`) +func (cli *testingServerClient) runTestIssue3682(t *testing.T) { + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec(`CREATE USER 'issue3682'@'%' IDENTIFIED BY '123';`) + dbt.MustExec(`GRANT ALL on test.* to 'issue3682'`) + dbt.MustExec(`GRANT ALL on mysql.* to 'issue3682'`) }) - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "issue3682" config.Passwd = "123" - }, func(dbt *DBTest) { - dbt.mustExec(`USE mysql;`) + }, func(dbt *testkit.DBTestKit) { + dbt.MustExec(`USE mysql;`) }) db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { config.User = "issue3682" config.Passwd = "wrong_password" config.DBName = "non_existing_schema" })) - c.Assert(err, IsNil) + require.NoError(t, err) defer func() { err := db.Close() - c.Assert(err, IsNil) + require.NoError(t, err) }() err = db.Ping() - c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "Error 1045: Access denied for user 'issue3682'@'127.0.0.1' (using password: YES)") + require.Error(t, err) + require.Equal(t, "Error 1045: Access denied for user 'issue3682'@'127.0.0.1' (using password: YES)", err.Error()) } func (cli *testServerClient) runTestDBNameEscape(c *C) { diff --git a/server/tidb_test.go b/server/tidb_test.go index f6996b7b5ce8b..19130c159bcee 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -305,10 +305,13 @@ func (ts *tidbTestSuite) TestErrorCode(c *C) { ts.runTestErrorCode(c) } -func (ts *tidbTestSuite) TestAuth(c *C) { - c.Parallel() - ts.runTestAuth(c) - ts.runTestIssue3682(c) +func TestAuth(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestAuth(t) + ts.runTestIssue3682(t) } func (ts *tidbTestSuite) TestIssues(c *C) { From c718343c24190d5826b050409c5101a144e97006 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 14:58:41 +0800 Subject: [PATCH 12/55] server: migrate `TestUint64` --- server/server_test.go | 12 ++++++------ server/tidb_test.go | 8 ++++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 2aeebb5694634..301464e8f5159 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -308,22 +308,22 @@ func (cli *testServerClient) runTestRegression(c *C, overrider configOverrider, }) } -func (cli *testServerClient) runTestPrepareResultFieldType(t *C) { +func (cli *testingServerClient) runTestPrepareResultFieldType(t *testing.T) { var param int64 = 83 - cli.runTests(t, nil, func(dbt *DBTest) { - stmt, err := dbt.db.Prepare(`SELECT ?`) + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { + stmt, err := dbt.GetDB().Prepare(`SELECT ?`) if err != nil { - dbt.Fatal(err) + t.Fatal(err) } defer stmt.Close() row := stmt.QueryRow(param) var result int64 err = row.Scan(&result) if err != nil { - dbt.Fatal(err) + t.Fatal(err) } if result != param { - dbt.Fatal("Unexpected result value") + t.Fatal("Unexpected result value") } }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index 19130c159bcee..5f0c92f30df1e 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -236,8 +236,12 @@ func (ts *tidbTestSuite) TestRegression(c *C) { } } -func (ts *tidbTestSuite) TestUint64(c *C) { - ts.runTestPrepareResultFieldType(c) +func TestUint64(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestPrepareResultFieldType(t) } func (ts *tidbTestSuite) TestSpecialType(c *C) { From 18e4d41f5a17abf18590136083778578e133d576 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 15:10:29 +0800 Subject: [PATCH 13/55] server: migrate `TestSpecialType` --- server/server_test.go | 59 +++++++++++++++++++++++++++++++++++-------- server/tidb_test.go | 9 ++++--- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 301464e8f5159..e889d7c79cbe0 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -189,6 +189,43 @@ func (cli *testServerClient) runTestsOnNewDB(c *C, overrider configOverrider, db } } +// runTestsOnNewDB runs tests using a specified database which will be created before the test and destroyed after the test. +func (cli *testingServerClient) runTestsOnNewDB(t *testing.T, overrider configOverrider, dbName string, tests ...func(dbt *testkit.DBTestKit)) { + dsn := cli.getDSN(overrider, func(config *mysql.Config) { + config.DBName = "" + }) + db, err := sql.Open("mysql", dsn) + require.NoErrorf(t, err, "Error connecting") + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + _, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`;", dbName)) + if err != nil { + fmt.Println(err) + } + require.NoErrorf(t, err, "Error drop database %s: %s", dbName, err) + + _, err = db.Exec(fmt.Sprintf("CREATE DATABASE `%s`;", dbName)) + require.NoErrorf(t, err, "Error create database %s: %s", dbName, err) + + defer func() { + _, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`;", dbName)) + require.NoErrorf(t, err, "Error drop database %s: %s", dbName, err) + }() + + _, err = db.Exec(fmt.Sprintf("USE `%s`;", dbName)) + require.NoErrorf(t, err, "Error use database %s: %s", dbName, err) + + dbt := testkit.NewDBTestKit(t, db) + for _, test := range tests { + test(dbt) + // to fix : no db selected + _, _ = dbt.GetDB().Exec("DROP TABLE IF EXISTS test") + } +} + type DBTest struct { *C db *sql.DB @@ -328,21 +365,21 @@ func (cli *testingServerClient) runTestPrepareResultFieldType(t *testing.T) { }) } -func (cli *testServerClient) runTestSpecialType(t *C) { - cli.runTestsOnNewDB(t, nil, "SpecialType", func(dbt *DBTest) { - dbt.mustExec("create table test (a decimal(10, 5), b datetime, c time, d bit(8))") - dbt.mustExec("insert test values (1.4, '2012-12-21 12:12:12', '4:23:34', b'1000')") - rows := dbt.mustQuery("select * from test where a > ?", 0) - t.Assert(rows.Next(), IsTrue) +func (cli *testingServerClient) runTestSpecialType(t *testing.T) { + cli.runTestsOnNewDB(t, nil, "SpecialType", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table test (a decimal(10, 5), b datetime, c time, d bit(8))") + dbt.MustExec("insert test values (1.4, '2012-12-21 12:12:12', '4:23:34', b'1000')") + rows := dbt.MustQuery("select * from test where a > ?", 0) + require.True(t, rows.Next()) var outA float64 var outB, outC string var outD []byte err := rows.Scan(&outA, &outB, &outC, &outD) - t.Assert(err, IsNil) - t.Assert(outA, Equals, 1.4) - t.Assert(outB, Equals, "2012-12-21 12:12:12") - t.Assert(outC, Equals, "04:23:34") - t.Assert(outD, BytesEquals, []byte{8}) + require.NoError(t, err) + require.Equal(t, 1.4, outA) + require.Equal(t, "2012-12-21 12:12:12", outB) + require.Equal(t, "04:23:34", outC) + require.Equal(t, []byte{8}, outD) }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index 5f0c92f30df1e..42dff92a84a4b 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -244,9 +244,12 @@ func TestUint64(t *testing.T) { ts.runTestPrepareResultFieldType(t) } -func (ts *tidbTestSuite) TestSpecialType(c *C) { - c.Parallel() - ts.runTestSpecialType(c) +func TestSpecialType(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestSpecialType(t) } func (ts *tidbTestSuite) TestPreparedString(c *C) { From 66a5eb32afc311e805b5387caebcc8f23ef48621 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 15:15:11 +0800 Subject: [PATCH 14/55] server: migrate `TestPreparedString` --- server/server_test.go | 18 +++++++++--------- server/tidb_test.go | 9 ++++++--- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index e889d7c79cbe0..2e452ce89b5c7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -418,17 +418,17 @@ func (cli *testServerClient) runTestClientWithCollation(t *C) { }) } -func (cli *testServerClient) runTestPreparedString(t *C) { - cli.runTestsOnNewDB(t, nil, "PreparedString", func(dbt *DBTest) { - dbt.mustExec("create table test (a char(10), b char(10))") - dbt.mustExec("insert test values (?, ?)", "abcdeabcde", "abcde") - rows := dbt.mustQuery("select * from test where 1 = ?", 1) - t.Assert(rows.Next(), IsTrue) +func (cli *testingServerClient) runTestPreparedString(t *testing.T) { + cli.runTestsOnNewDB(t, nil, "PreparedString", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table test (a char(10), b char(10))") + dbt.MustExec("insert test values (?, ?)", "abcdeabcde", "abcde") + rows := dbt.MustQuery("select * from test where 1 = ?", 1) + require.True(t, rows.Next()) var outA, outB string err := rows.Scan(&outA, &outB) - t.Assert(err, IsNil) - t.Assert(outA, Equals, "abcdeabcde") - t.Assert(outB, Equals, "abcde") + require.NoError(t, err) + require.Equal(t, "abcdeabcde", outA) + require.Equal(t, "abcde", outB) }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index 42dff92a84a4b..feaa65282a7f1 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -252,9 +252,12 @@ func TestSpecialType(t *testing.T) { ts.runTestSpecialType(t) } -func (ts *tidbTestSuite) TestPreparedString(c *C) { - c.Parallel() - ts.runTestPreparedString(c) +func TestPreparedString(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestPreparedString(t) } func (ts *tidbTestSuite) TestPreparedTimestamp(c *C) { From e9a3f7eb537adb6f25097a70b561149987c83db8 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 15:25:04 +0800 Subject: [PATCH 15/55] server: migrate `TestPreparedTimestamp` --- server/server_test.go | 24 ++++++++++++------------ server/tidb_test.go | 9 ++++++--- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 2e452ce89b5c7..922b4b8dd294a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -435,24 +435,24 @@ func (cli *testingServerClient) runTestPreparedString(t *testing.T) { // runTestPreparedTimestamp does not really cover binary timestamp format, because MySQL driver in golang // does not use this format. MySQL driver in golang will convert the timestamp to a string. // This case guarantees it could work. -func (cli *testServerClient) runTestPreparedTimestamp(t *C) { - cli.runTestsOnNewDB(t, nil, "prepared_timestamp", func(dbt *DBTest) { - dbt.mustExec("create table test (a timestamp, b time)") - dbt.mustExec("set time_zone='+00:00'") - insertStmt := dbt.mustPrepare("insert test values (?, ?)") +func (cli *testingServerClient) runTestPreparedTimestamp(t *testing.T) { + cli.runTestsOnNewDB(t, nil, "prepared_timestamp", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table test (a timestamp, b time)") + dbt.MustExec("set time_zone='+00:00'") + insertStmt := dbt.MustPrepare("insert test values (?, ?)") defer insertStmt.Close() vts := time.Unix(1, 1) vt := time.Unix(-1, 1) - dbt.mustExecPrepared(insertStmt, vts, vt) - selectStmt := dbt.mustPrepare("select * from test where a = ? and b = ?") + dbt.MustExecPrepared(insertStmt, vts, vt) + selectStmt := dbt.MustPrepare("select * from test where a = ? and b = ?") defer selectStmt.Close() - rows := dbt.mustQueryPrepared(selectStmt, vts, vt) - t.Assert(rows.Next(), IsTrue) + rows := dbt.MustQueryPrepared(selectStmt, vts, vt) + require.True(t, rows.Next()) var outA, outB string err := rows.Scan(&outA, &outB) - t.Assert(err, IsNil) - t.Assert(outA, Equals, "1970-01-01 00:00:01") - t.Assert(outB, Equals, "23:59:59") + require.NoError(t, err) + require.Equal(t, "1970-01-01 00:00:01", outA) + require.Equal(t, "23:59:59", outB) }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index feaa65282a7f1..2724675607f21 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -260,9 +260,12 @@ func TestPreparedString(t *testing.T) { ts.runTestPreparedString(t) } -func (ts *tidbTestSuite) TestPreparedTimestamp(c *C) { - c.Parallel() - ts.runTestPreparedTimestamp(c) +func TestPreparedTimestamp(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestPreparedTimestamp(t) } func (ts *tidbTestSerialSuite) TestConfigDefaultValue(c *C) { From a26a8909271801234b98d24f962e1d0259802985 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 15:29:30 +0800 Subject: [PATCH 16/55] server: migrate `TestConcurrentUpdate` --- server/server_test.go | 34 +++++++++++++++++----------------- server/tidb_test.go | 9 ++++++--- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 922b4b8dd294a..96934a38c7a16 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1473,36 +1473,36 @@ func (cli *testServerClient) runTestLoadData(c *C, server *Server) { }) } -func (cli *testServerClient) runTestConcurrentUpdate(c *C) { +func (cli *testingServerClient) runTestConcurrentUpdate(t *testing.T) { dbName := "Concurrent" - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.Params["sql_mode"] = "''" - }, dbName, func(dbt *DBTest) { - dbt.mustExec("drop table if exists test2") - dbt.mustExec("create table test2 (a int, b int)") - dbt.mustExec("insert test2 values (1, 1)") - dbt.mustExec("set @@tidb_disable_txn_auto_retry = 0") + }, dbName, func(dbt *testkit.DBTestKit) { + dbt.MustExec("drop table if exists test2") + dbt.MustExec("create table test2 (a int, b int)") + dbt.MustExec("insert test2 values (1, 1)") + dbt.MustExec("set @@tidb_disable_txn_auto_retry = 0") - txn1, err := dbt.db.Begin() - c.Assert(err, IsNil) + txn1, err := dbt.GetDB().Begin() + require.NoError(t, err) _, err = txn1.Exec(fmt.Sprintf("USE `%s`;", dbName)) - c.Assert(err, IsNil) + require.NoError(t, err) - txn2, err := dbt.db.Begin() - c.Assert(err, IsNil) + txn2, err := dbt.GetDB().Begin() + require.NoError(t, err) _, err = txn2.Exec(fmt.Sprintf("USE `%s`;", dbName)) - c.Assert(err, IsNil) + require.NoError(t, err) _, err = txn2.Exec("update test2 set a = a + 1 where b = 1") - c.Assert(err, IsNil) + require.NoError(t, err) err = txn2.Commit() - c.Assert(err, IsNil) + require.NoError(t, err) _, err = txn1.Exec("update test2 set a = a + 1 where b = 1") - c.Assert(err, IsNil) + require.NoError(t, err) err = txn1.Commit() - c.Assert(err, IsNil) + require.NoError(t, err) }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index 2724675607f21..1aa093ffa229e 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -308,9 +308,12 @@ func (ts *tidbTestSerialSuite) TestStmtCount(c *C) { ts.runTestStmtCount(c) } -func (ts *tidbTestSuite) TestConcurrentUpdate(c *C) { - c.Parallel() - ts.runTestConcurrentUpdate(c) +func TestConcurrentUpdate(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestConcurrentUpdate(t) } func (ts *tidbTestSuite) TestErrorCode(c *C) { From deeea387b59e4b41637ae5b09a70d0f134710cef Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 17:23:11 +0800 Subject: [PATCH 17/55] server: migrate `TestErrorCode` and `TestLoadData` --- server/server_test.go | 759 +++++++++++++++++++------------------ server/tidb_serial_test.go | 28 ++ server/tidb_test.go | 17 +- testkit/dbtestkit.go | 24 ++ 4 files changed, 438 insertions(+), 390 deletions(-) create mode 100644 server/tidb_serial_test.go diff --git a/server/server_test.go b/server/server_test.go index 96934a38c7a16..6e23ad469ea6f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -456,62 +456,63 @@ func (cli *testingServerClient) runTestPreparedTimestamp(t *testing.T) { }) } -func (cli *testServerClient) runTestLoadDataWithSelectIntoOutfile(c *C, server *Server) { - cli.runTestsOnNewDB(c, func(config *mysql.Config) { +func (cli *testingServerClient) runTestLoadDataWithSelectIntoOutfile(t *testing.T, server *Server) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "SelectIntoOutfile", func(dbt *DBTest) { - dbt.mustExec("create table t (i int, r real, d decimal(10, 5), s varchar(100), dt datetime, ts timestamp, j json)") - dbt.mustExec("insert into t values (1, 1.1, 0.1, 'a', '2000-01-01', '01:01:01', '[1]')") - dbt.mustExec("insert into t values (2, 2.2, 0.2, 'b', '2000-02-02', '02:02:02', '[1,2]')") - dbt.mustExec("insert into t values (null, null, null, null, '2000-03-03', '03:03:03', '[1,2,3]')") - dbt.mustExec("insert into t values (4, 4.4, 0.4, 'd', null, null, null)") + }, "SelectIntoOutfile", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (i int, r real, d decimal(10, 5), s varchar(100), dt datetime, ts timestamp, j json)") + dbt.MustExec("insert into t values (1, 1.1, 0.1, 'a', '2000-01-01', '01:01:01', '[1]')") + dbt.MustExec("insert into t values (2, 2.2, 0.2, 'b', '2000-02-02', '02:02:02', '[1,2]')") + dbt.MustExec("insert into t values (null, null, null, null, '2000-03-03', '03:03:03', '[1,2,3]')") + dbt.MustExec("insert into t values (4, 4.4, 0.4, 'd', null, null, null)") outfile := filepath.Join(os.TempDir(), fmt.Sprintf("select_into_outfile_%v_%d.csv", time.Now().UnixNano(), rand.Int())) // On windows use fmt.Sprintf("%q") to escape \ for SQL, // outfile may be 'C:\Users\genius\AppData\Local\Temp\select_into_outfile_1582732846769492000_8074605509026837941.csv' // Without quote, after SQL escape it would become: // 'C:UsersgeniusAppDataLocalTempselect_into_outfile_1582732846769492000_8074605509026837941.csv' - dbt.mustExec(fmt.Sprintf("select * from t into outfile %q", outfile)) + dbt.MustExec(fmt.Sprintf("select * from t into outfile %q", outfile)) defer func() { - c.Assert(os.Remove(outfile), IsNil) + require.NoError(t, os.Remove(outfile)) }() - dbt.mustExec("create table t1 (i int, r real, d decimal(10, 5), s varchar(100), dt datetime, ts timestamp, j json)") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t1", outfile)) + dbt.MustExec("create table t1 (i int, r real, d decimal(10, 5), s varchar(100), dt datetime, ts timestamp, j json)") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t1", outfile)) fetchResults := func(table string) [][]interface{} { var res [][]interface{} - row := dbt.mustQuery("select * from " + table + " order by i") + row := dbt.MustQuery("select * from " + table + " order by i") for row.Next() { r := make([]interface{}, 7) - c.Assert(row.Scan(&r[0], &r[1], &r[2], &r[3], &r[4], &r[5], &r[6]), IsNil) + require.NoError(t, row.Scan(&r[0], &r[1], &r[2], &r[3], &r[4], &r[5], &r[6])) res = append(res, r) } - c.Assert(row.Close(), IsNil) + require.NoError(t, row.Close()) return res } res := fetchResults("t") res1 := fetchResults("t1") - c.Assert(len(res), Equals, len(res1)) + require.Equal(t, len(res1), len(res)) for i := range res { for j := range res[i] { // using Sprintf to avoid some uncomparable types - c.Assert(fmt.Sprintf("%v", res[i][j]), Equals, fmt.Sprintf("%v", res1[i][j])) + require.Equal(t, fmt.Sprintf("%v", res1[i][j]), fmt.Sprintf("%v", res[i][j])) } } }) } -func (cli *testServerClient) runTestLoadDataForSlowLog(c *C, server *Server) { + +func (cli *testingServerClient) runTestLoadDataForSlowLog(t *testing.T, server *Server) { path := "/tmp/load_data_test.csv" fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) defer func() { err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) }() _, err = fp.WriteString( "1 1\n" + @@ -519,42 +520,42 @@ func (cli *testServerClient) runTestLoadDataForSlowLog(c *C, server *Server) { "3 3\n" + "4 4\n" + "5 5\n") - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "load_data_slow_query", func(dbt *DBTest) { - dbt.mustExec("create table t_slow (a int key, b int)") + }, "load_data_slow_query", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t_slow (a int key, b int)") defer func() { - dbt.mustExec("set tidb_slow_log_threshold=300;") - dbt.mustExec("set @@global.tidb_enable_stmt_summary=0") + dbt.MustExec("set tidb_slow_log_threshold=300;") + dbt.MustExec("set @@global.tidb_enable_stmt_summary=0") }() - dbt.mustExec("set tidb_slow_log_threshold=0;") - dbt.mustExec("set @@global.tidb_enable_stmt_summary=1") + dbt.MustExec("set tidb_slow_log_threshold=0;") + dbt.MustExec("set @@global.tidb_enable_stmt_summary=1") query := fmt.Sprintf("load data local infile %q into table t_slow", path) - dbt.mustExec(query) - dbt.mustExec("insert ignore into t_slow values (1,1);") + dbt.MustExec(query) + dbt.MustExec("insert ignore into t_slow values (1,1);") checkPlan := func(rows *sql.Rows, expectPlan string) { - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.Truef(t, rows.Next(), "unexpected data") var plan sql.NullString err = rows.Scan(&plan) - dbt.Check(err, IsNil) + require.NoError(t, err) planStr := strings.ReplaceAll(plan.String, "\t", " ") planStr = strings.ReplaceAll(planStr, "\n", " ") - c.Assert(planStr, Matches, expectPlan) + require.Regexp(t, expectPlan, planStr) } // Test for record slow log for load data statement. - rows := dbt.mustQuery("select plan from information_schema.slow_query where query like 'load data local infile % into table t_slow;' order by time desc limit 1") + rows := dbt.MustQuery("select plan from information_schema.slow_query where query like 'load data local infile % into table t_slow;' order by time desc limit 1") expectedPlan := ".*LoadData.* time.* loops.* prepare.* check_insert.* mem_insert_time:.* prefetch.* rpc.* commit_txn.*" checkPlan(rows, expectedPlan) // Test for record statements_summary for load data statement. - rows = dbt.mustQuery("select plan from information_schema.STATEMENTS_SUMMARY where QUERY_SAMPLE_TEXT like 'load data local infile %' limit 1") + rows = dbt.MustQuery("select plan from information_schema.STATEMENTS_SUMMARY where QUERY_SAMPLE_TEXT like 'load data local infile %' limit 1") checkPlan(rows, expectedPlan) // Test log normal statement after executing load date. - rows = dbt.mustQuery("select plan from information_schema.slow_query where query = 'insert ignore into t_slow values (1,1);' order by time desc limit 1") + rows = dbt.MustQuery("select plan from information_schema.slow_query where query = 'insert ignore into t_slow values (1,1);' order by time desc limit 1") expectedPlan = ".*Insert.* time.* loops.* prepare.* check_insert.* mem_insert_time:.* prefetch.* rpc.*" checkPlan(rows, expectedPlan) }) @@ -919,17 +920,17 @@ func (cli *testServerClient) checkRows(c *C, rows *sql.Rows, expectedRows ...str c.Assert(strings.Join(result, "\n"), Equals, strings.Join(expectedRows, "\n")) } -func (cli *testServerClient) runTestLoadData(c *C, server *Server) { +func (cli *testingServerClient) runTestLoadData(t *testing.T, server *Server) { // create a file and write data. path := "/tmp/load_data_test.csv" fp, err := os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) defer func() { err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) }() _, err = fp.WriteString("\n" + "xxx row1_col1 - row1_col2 1abc\n" + @@ -937,7 +938,7 @@ func (cli *testServerClient) runTestLoadData(c *C, server *Server) { "xxxy row3_col1 - row3_col2 \n" + "xxx row4_col1 - 900\n" + "xxx row5_col1 - row5_col3") - c.Assert(err, IsNil) + require.NoError(t, err) originalTxnTotalSizeLimit := kv.TxnTotalSizeLimit // If the MemBuffer can't be committed once in each batch, it will return an error like "transaction is too large". @@ -945,531 +946,531 @@ func (cli *testServerClient) runTestLoadData(c *C, server *Server) { defer func() { kv.TxnTotalSizeLimit = originalTxnTotalSizeLimit }() // support ClientLocalFiles capability - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("set @@tidb_dml_batch_size = 3") - dbt.mustExec("create table test (a varchar(255), b varchar(255) default 'default value', c int not null auto_increment, primary key(c))") - dbt.mustExec("create view v1 as select 1") - dbt.mustExec("create sequence s1") + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("set @@tidb_dml_batch_size = 3") + dbt.MustExec("create table test (a varchar(255), b varchar(255) default 'default value', c int not null auto_increment, primary key(c))") + dbt.MustExec("create view v1 as select 1") + dbt.MustExec("create sequence s1") // can't insert into views (in TiDB) or sequences. issue #20880 - _, err = dbt.db.Exec("load data local infile '/tmp/load_data_test.csv' into table v1") - dbt.Assert(err, NotNil) - dbt.Assert(err.Error(), Equals, "Error 1105: can only load data into base tables") - _, err = dbt.db.Exec("load data local infile '/tmp/load_data_test.csv' into table s1") - dbt.Assert(err, NotNil) - dbt.Assert(err.Error(), Equals, "Error 1105: can only load data into base tables") - - rs, err1 := dbt.db.Exec("load data local infile '/tmp/load_data_test.csv' into table test") - dbt.Assert(err1, IsNil) + _, err = dbt.GetDB().Exec("load data local infile '/tmp/load_data_test.csv' into table v1") + require.Error(t, err) + require.Equal(t, "Error 1105: can only load data into base tables", err.Error()) + _, err = dbt.GetDB().Exec("load data local infile '/tmp/load_data_test.csv' into table s1") + require.Error(t, err) + require.Equal(t, "Error 1105: can only load data into base tables", err.Error()) + + rs, err1 := dbt.GetDB().Exec("load data local infile '/tmp/load_data_test.csv' into table test") + require.NoError(t, err1) lastID, err1 := rs.LastInsertId() - dbt.Assert(err1, IsNil) - dbt.Assert(lastID, Equals, int64(1)) + require.NoError(t, err1) + require.Equal(t, int64(1), lastID) affectedRows, err1 := rs.RowsAffected() - dbt.Assert(err1, IsNil) - dbt.Assert(affectedRows, Equals, int64(5)) + require.NoError(t, err1) + require.Equal(t, int64(5), affectedRows) var ( a string b string bb sql.NullString cc int ) - rows := dbt.mustQuery("select * from test") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows := dbt.MustQuery("select * from test") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &bb, &cc) - dbt.Check(err, IsNil) - dbt.Check(a, DeepEquals, "") - dbt.Check(bb.String, DeepEquals, "") - dbt.Check(cc, DeepEquals, 1) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Empty(t, a) + require.Empty(t, bb.String) + require.Equal(t, 1, cc) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &cc) - c.Assert(err, IsNil) - dbt.Check(a, DeepEquals, "xxx row2_col1") - dbt.Check(b, DeepEquals, "- row2_col2") - dbt.Check(cc, DeepEquals, 2) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "xxx row2_col1", a) + require.Equal(t, "- row2_col2", b) + require.Equal(t, 2, cc) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &cc) - c.Assert(err, IsNil) - dbt.Check(a, DeepEquals, "xxxy row3_col1") - dbt.Check(b, DeepEquals, "- row3_col2") - dbt.Check(cc, DeepEquals, 3) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "xxxy row3_col1", a) + require.Equal(t, "- row3_col2", b) + require.Equal(t, 3, cc) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &cc) - c.Assert(err, IsNil) - dbt.Check(a, DeepEquals, "xxx row4_col1") - dbt.Check(b, DeepEquals, "- ") - dbt.Check(cc, DeepEquals, 4) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "xxx row4_col1", a) + require.Equal(t, "- ", b) + require.Equal(t, 4, cc) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &cc) - c.Assert(err, IsNil) - dbt.Check(a, DeepEquals, "xxx row5_col1") - dbt.Check(b, DeepEquals, "- ") - dbt.Check(cc, DeepEquals, 5) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "xxx row5_col1", a) + require.Equal(t, "- ", b) + require.Equal(t, 5, cc) + require.Falsef(t, rows.Next(), "unexpected data") rows.Close() // specify faileds and lines - dbt.mustExec("delete from test") - dbt.mustExec("set @@tidb_dml_batch_size = 3") - rs, err = dbt.db.Exec("load data local infile '/tmp/load_data_test.csv' into table test fields terminated by '\t- ' lines starting by 'xxx ' terminated by '\n'") - dbt.Assert(err, IsNil) + dbt.MustExec("delete from test") + dbt.MustExec("set @@tidb_dml_batch_size = 3") + rs, err = dbt.GetDB().Exec("load data local infile '/tmp/load_data_test.csv' into table test fields terminated by '\t- ' lines starting by 'xxx ' terminated by '\n'") + require.NoError(t, err) lastID, err = rs.LastInsertId() - dbt.Assert(err, IsNil) - dbt.Assert(lastID, Equals, int64(6)) + require.NoError(t, err) + require.Equal(t, int64(6), lastID) affectedRows, err = rs.RowsAffected() - dbt.Assert(err, IsNil) - dbt.Assert(affectedRows, Equals, int64(4)) - rows = dbt.mustQuery("select * from test") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, int64(4), affectedRows) + rows = dbt.MustQuery("select * from test") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &cc) - c.Assert(err, IsNil) - dbt.Check(a, DeepEquals, "row1_col1") - dbt.Check(b, DeepEquals, "row1_col2\t1abc") - dbt.Check(cc, DeepEquals, 6) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "row1_col1", a) + require.Equal(t, "row1_col2\t1abc", b) + require.Equal(t, 6, cc) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &cc) - c.Assert(err, IsNil) - dbt.Check(a, DeepEquals, "row2_col1") - dbt.Check(b, DeepEquals, "row2_col2\t") - dbt.Check(cc, DeepEquals, 7) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "row2_col1", a) + require.Equal(t, "row2_col2\t", b) + require.Equal(t, 7, cc) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &cc) - c.Assert(err, IsNil) - dbt.Check(a, DeepEquals, "row4_col1") - dbt.Check(b, DeepEquals, "\t\t900") - dbt.Check(cc, DeepEquals, 8) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "row4_col1", a) + require.Equal(t, "\t\t900", b) + require.Equal(t, 8, cc) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &cc) - c.Assert(err, IsNil) - dbt.Check(a, DeepEquals, "row5_col1") - dbt.Check(b, DeepEquals, "\trow5_col3") - dbt.Check(cc, DeepEquals, 9) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "row5_col1", a) + require.Equal(t, "\trow5_col3", b) + require.Equal(t, 9, cc) + require.Falsef(t, rows.Next(), "unexpected data") // infile size more than a packet size(16K) - dbt.mustExec("delete from test") + dbt.MustExec("delete from test") _, err = fp.WriteString("\n") - dbt.Assert(err, IsNil) + require.NoError(t, err) for i := 6; i <= 800; i++ { _, err = fp.WriteString(fmt.Sprintf("xxx row%d_col1 - row%d_col2\n", i, i)) - dbt.Assert(err, IsNil) + require.NoError(t, err) } - dbt.mustExec("set @@tidb_dml_batch_size = 3") - rs, err = dbt.db.Exec("load data local infile '/tmp/load_data_test.csv' into table test fields terminated by '\t- ' lines starting by 'xxx ' terminated by '\n'") - dbt.Assert(err, IsNil) + dbt.MustExec("set @@tidb_dml_batch_size = 3") + rs, err = dbt.GetDB().Exec("load data local infile '/tmp/load_data_test.csv' into table test fields terminated by '\t- ' lines starting by 'xxx ' terminated by '\n'") + require.NoError(t, err) lastID, err = rs.LastInsertId() - dbt.Assert(err, IsNil) - dbt.Assert(lastID, Equals, int64(10)) + require.NoError(t, err) + require.Equal(t, int64(10), lastID) affectedRows, err = rs.RowsAffected() - dbt.Assert(err, IsNil) - dbt.Assert(affectedRows, Equals, int64(799)) - rows = dbt.mustQuery("select * from test") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, int64(799), affectedRows) + rows = dbt.MustQuery("select * from test") + require.Truef(t, rows.Next(), "unexpected data") // don't support lines terminated is "" - dbt.mustExec("set @@tidb_dml_batch_size = 3") - _, err = dbt.db.Exec("load data local infile '/tmp/load_data_test.csv' into table test lines terminated by ''") - dbt.Assert(err, NotNil) + dbt.MustExec("set @@tidb_dml_batch_size = 3") + _, err = dbt.GetDB().Exec("load data local infile '/tmp/load_data_test.csv' into table test lines terminated by ''") + require.NotNil(t, err) // infile doesn't exist - dbt.mustExec("set @@tidb_dml_batch_size = 3") - _, err = dbt.db.Exec("load data local infile '/tmp/nonexistence.csv' into table test") - dbt.Assert(err, NotNil) + dbt.MustExec("set @@tidb_dml_batch_size = 3") + _, err = dbt.GetDB().Exec("load data local infile '/tmp/nonexistence.csv' into table test") + require.NotNil(t, err) }) err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) fp, err = os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) // Test mixed unenclosed and enclosed fields. _, err = fp.WriteString( "\"abc\",123\n" + "def,456,\n" + "hig,\"789\",") - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("create table test (str varchar(10) default null, i int default null)") - dbt.mustExec("set @@tidb_dml_batch_size = 3") - _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`) - dbt.Assert(err1, IsNil) + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table test (str varchar(10) default null, i int default null)") + dbt.MustExec("set @@tidb_dml_batch_size = 3") + _, err1 := dbt.GetDB().Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`) + require.NoError(t, err1) var ( str string id int ) - rows := dbt.mustQuery("select * from test") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows := dbt.MustQuery("select * from test") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&str, &id) - dbt.Check(err, IsNil) - dbt.Check(str, DeepEquals, "abc") - dbt.Check(id, DeepEquals, 123) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "abc", str) + require.Equal(t, 123, id) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&str, &id) - c.Assert(err, IsNil) - dbt.Check(str, DeepEquals, "def") - dbt.Check(id, DeepEquals, 456) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "def", str) + require.Equal(t, 456, id) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&str, &id) - c.Assert(err, IsNil) - dbt.Check(str, DeepEquals, "hig") - dbt.Check(id, DeepEquals, 789) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) - dbt.mustExec("delete from test") + require.NoError(t, err) + require.Equal(t, "hig", str) + require.Equal(t, 789, id) + require.Falsef(t, rows.Next(), "unexpected data") + dbt.MustExec("delete from test") }) err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) fp, err = os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) // Test irregular csv file. _, err = fp.WriteString( `,\N,NULL,,` + "\n" + "00,0,000000,,\n" + `2003-03-03, 20030303,030303,\N` + "\n") - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("create table test (a date, b date, c date not null, d date)") - dbt.mustExec("set @@tidb_dml_batch_size = 3") - _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ','`) - dbt.Assert(err1, IsNil) + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table test (a date, b date, c date not null, d date)") + dbt.MustExec("set @@tidb_dml_batch_size = 3") + _, err1 := dbt.GetDB().Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ','`) + require.NoError(t, err1) var ( a sql.NullString b sql.NullString d sql.NullString c sql.NullString ) - rows := dbt.mustQuery("select * from test") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows := dbt.MustQuery("select * from test") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &c, &d) - dbt.Check(err, IsNil) - dbt.Check(a.String, Equals, "0000-00-00") - dbt.Check(b.String, Equals, "") - dbt.Check(c.String, Equals, "0000-00-00") - dbt.Check(d.String, Equals, "0000-00-00") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "0000-00-00", a.String) + require.Empty(t, b.String) + require.Equal(t, "0000-00-00", c.String) + require.Equal(t, "0000-00-00", d.String) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &c, &d) - dbt.Check(err, IsNil) - dbt.Check(a.String, Equals, "0000-00-00") - dbt.Check(b.String, Equals, "0000-00-00") - dbt.Check(c.String, Equals, "0000-00-00") - dbt.Check(d.String, Equals, "0000-00-00") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "0000-00-00", a.String) + require.Equal(t, "0000-00-00", b.String) + require.Equal(t, "0000-00-00", c.String) + require.Equal(t, "0000-00-00", d.String) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &c, &d) - dbt.Check(err, IsNil) - dbt.Check(a.String, Equals, "2003-03-03") - dbt.Check(b.String, Equals, "2003-03-03") - dbt.Check(c.String, Equals, "2003-03-03") - dbt.Check(d.String, Equals, "") - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) - dbt.mustExec("delete from test") + require.NoError(t, err) + require.Equal(t, "2003-03-03", a.String) + require.Equal(t, "2003-03-03", b.String) + require.Equal(t, "2003-03-03", c.String) + require.Equal(t, "", d.String) + require.Falsef(t, rows.Next(), "unexpected data") + dbt.MustExec("delete from test") }) err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) fp, err = os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) // Test double enclosed. _, err = fp.WriteString( `"field1","field2"` + "\n" + `"a""b","cd""ef"` + "\n" + `"a"b",c"d"e` + "\n") - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("create table test (a varchar(20), b varchar(20))") - dbt.mustExec("set @@tidb_dml_batch_size = 3") - _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`) - dbt.Assert(err1, IsNil) + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table test (a varchar(20), b varchar(20))") + dbt.MustExec("set @@tidb_dml_batch_size = 3") + _, err1 := dbt.GetDB().Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`) + require.NoError(t, err1) var ( a sql.NullString b sql.NullString ) - rows := dbt.mustQuery("select * from test") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows := dbt.MustQuery("select * from test") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b) - dbt.Check(err, IsNil) - dbt.Check(a.String, Equals, "field1") - dbt.Check(b.String, Equals, "field2") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, "field1", a.String) + require.Equal(t, "field2", b.String) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b) - c.Assert(err, IsNil) - dbt.Check(a.String, Equals, `a"b`) - dbt.Check(b.String, Equals, `cd"ef`) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, `a"b`, a.String) + require.Equal(t, `cd"ef`, b.String) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b) - c.Assert(err, IsNil) - dbt.Check(a.String, Equals, `a"b`) - dbt.Check(b.String, Equals, `c"d"e`) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) - dbt.mustExec("delete from test") + require.NoError(t, err) + require.Equal(t, `a"b`, a.String) + require.Equal(t, `c"d"e`, b.String) + require.Falsef(t, rows.Next(), "unexpected data") + dbt.MustExec("delete from test") }) err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) fp, err = os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) // Test OPTIONALLY _, err = fp.WriteString( `"a,b,c` + "\n" + `"1",2,"3"` + "\n") - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("create table test (id INT NOT NULL PRIMARY KEY, b INT, c varchar(10))") - dbt.mustExec("set @@tidb_dml_batch_size = 3") - _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' IGNORE 1 LINES`) - dbt.Assert(err1, IsNil) + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table test (id INT NOT NULL PRIMARY KEY, b INT, c varchar(10))") + dbt.MustExec("set @@tidb_dml_batch_size = 3") + _, err1 := dbt.GetDB().Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' IGNORE 1 LINES`) + require.NoError(t, err1) var ( a int b int c sql.NullString ) - rows := dbt.mustQuery("select * from test") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows := dbt.MustQuery("select * from test") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &c) - dbt.Check(err, IsNil) - dbt.Check(a, Equals, 1) - dbt.Check(b, Equals, 2) - dbt.Check(c.String, Equals, "3") - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) - dbt.mustExec("delete from test") + require.NoError(t, err) + require.Equal(t, 1, a) + require.Equal(t, 2, b) + require.Equal(t, "3", c.String) + require.Falsef(t, rows.Next(), "unexpected data") + dbt.MustExec("delete from test") }) // unsupport ClientLocalFiles capability server.capability ^= tmysql.ClientLocalFiles - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("create table test (a varchar(255), b varchar(255) default 'default value', c int not null auto_increment, primary key(c))") - dbt.mustExec("set @@tidb_dml_batch_size = 3") - _, err = dbt.db.Exec("load data local infile '/tmp/load_data_test.csv' into table test") - dbt.Assert(err, NotNil) - checkErrorCode(c, err, errno.ErrNotAllowedCommand) + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table test (a varchar(255), b varchar(255) default 'default value', c int not null auto_increment, primary key(c))") + dbt.MustExec("set @@tidb_dml_batch_size = 3") + _, err = dbt.GetDB().Exec("load data local infile '/tmp/load_data_test.csv' into table test") + require.Error(t, err) + checkErrorCode(t, err, errno.ErrNotAllowedCommand) }) server.capability |= tmysql.ClientLocalFiles err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) fp, err = os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) // Test OPTIONALLY _, err = fp.WriteString( `1,2` + "\n" + `3,4` + "\n") - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("drop table if exists pn") - dbt.mustExec("create table pn (c1 int, c2 int)") - dbt.mustExec("set @@tidb_dml_batch_size = 1") - _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ','`) - dbt.Assert(err1, IsNil) + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("drop table if exists pn") + dbt.MustExec("create table pn (c1 int, c2 int)") + dbt.MustExec("set @@tidb_dml_batch_size = 1") + _, err1 := dbt.GetDB().Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ','`) + require.NoError(t, err1) var ( a int b int ) - rows := dbt.mustQuery("select * from pn") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows := dbt.MustQuery("select * from pn") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b) - dbt.Check(err, IsNil) - dbt.Check(a, Equals, 1) - dbt.Check(b, Equals, 2) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, 1, a) + require.Equal(t, 2, b) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b) - dbt.Check(err, IsNil) - dbt.Check(a, Equals, 3) - dbt.Check(b, Equals, 4) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, 3, a) + require.Equal(t, 4, b) + require.Falsef(t, rows.Next(), "unexpected data") // fail error processing test - dbt.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/commitOneTaskErr", "return"), IsNil) - _, err1 = dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ','`) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/commitOneTaskErr", "return")) + _, err1 = dbt.GetDB().Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ','`) mysqlErr, ok := err1.(*mysql.MySQLError) - dbt.Assert(ok, IsTrue) - dbt.Assert(mysqlErr.Message, Equals, "mock commit one task error") - dbt.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/commitOneTaskErr"), IsNil) + require.True(t, ok) + require.Equal(t, "mock commit one task error", mysqlErr.Message) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/commitOneTaskErr")) - dbt.mustExec("drop table if exists pn") + dbt.MustExec("drop table if exists pn") }) err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) fp, err = os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) // Test Column List Specification _, err = fp.WriteString( `1,2` + "\n" + `3,4` + "\n") - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("drop table if exists pn") - dbt.mustExec("create table pn (c1 int, c2 int)") - dbt.mustExec("set @@tidb_dml_batch_size = 1") - _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ',' (c1, c2)`) - dbt.Assert(err1, IsNil) + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("drop table if exists pn") + dbt.MustExec("create table pn (c1 int, c2 int)") + dbt.MustExec("set @@tidb_dml_batch_size = 1") + _, err1 := dbt.GetDB().Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ',' (c1, c2)`) + require.NoError(t, err1) var ( a int b int ) - rows := dbt.mustQuery("select * from pn") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows := dbt.MustQuery("select * from pn") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b) - dbt.Check(err, IsNil) - dbt.Check(a, Equals, 1) - dbt.Check(b, Equals, 2) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, 1, a) + require.Equal(t, 2, b) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b) - dbt.Check(err, IsNil) - dbt.Check(a, Equals, 3) - dbt.Check(b, Equals, 4) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, 3, a) + require.Equal(t, 4, b) + require.Falsef(t, rows.Next(), "unexpected data") - dbt.mustExec("drop table if exists pn") + dbt.MustExec("drop table if exists pn") }) err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) fp, err = os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) // Test Column List Specification _, err = fp.WriteString( `1,2,3` + "\n" + `4,5,6` + "\n") - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("drop table if exists pn") - dbt.mustExec("create table pn (c1 int, c2 int, c3 int)") - dbt.mustExec("set @@tidb_dml_batch_size = 1") - _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ',' (c1, @dummy)`) - dbt.Assert(err1, IsNil) + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("drop table if exists pn") + dbt.MustExec("create table pn (c1 int, c2 int, c3 int)") + dbt.MustExec("set @@tidb_dml_batch_size = 1") + _, err1 := dbt.GetDB().Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ',' (c1, @dummy)`) + require.NoError(t, err1) var ( a int b sql.NullString c sql.NullString ) - rows := dbt.mustQuery("select * from pn") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows := dbt.MustQuery("select * from pn") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &c) - dbt.Check(err, IsNil) - dbt.Check(a, Equals, 1) - dbt.Check(b.String, Equals, "") - dbt.Check(c.String, Equals, "") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, 1, a) + require.Empty(t, b.String) + require.Empty(t, c.String) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &c) - dbt.Check(err, IsNil) - dbt.Check(a, Equals, 4) - dbt.Check(b.String, Equals, "") - dbt.Check(c.String, Equals, "") - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, 4, a) + require.Empty(t, b.String) + require.Empty(t, c.String) + require.Falsef(t, rows.Next(), "unexpected data") - dbt.mustExec("drop table if exists pn") + dbt.MustExec("drop table if exists pn") }) err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) err = os.Remove(path) - c.Assert(err, IsNil) + require.NoError(t, err) fp, err = os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) // Test Input Preprocessing _, err = fp.WriteString( `1,2,3` + "\n" + `4,5,6` + "\n") - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "LoadData", func(dbt *DBTest) { - dbt.mustExec("drop table if exists pn") - dbt.mustExec("create table pn (c1 int, c2 int, c3 int)") - dbt.mustExec("set @@tidb_dml_batch_size = 1") - _, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ',' (c1, @val1, @val2) SET c3 = @val2 * 100, c2 = CAST(@val1 AS UNSIGNED)`) - dbt.Assert(err1, IsNil) + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("drop table if exists pn") + dbt.MustExec("create table pn (c1 int, c2 int, c3 int)") + dbt.MustExec("set @@tidb_dml_batch_size = 1") + _, err1 := dbt.GetDB().Exec(`load data local infile '/tmp/load_data_test.csv' into table pn FIELDS TERMINATED BY ',' (c1, @val1, @val2) SET c3 = @val2 * 100, c2 = CAST(@val1 AS UNSIGNED)`) + require.NoError(t, err1) var ( a int b int c int ) - rows := dbt.mustQuery("select * from pn") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + rows := dbt.MustQuery("select * from pn") + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &c) - dbt.Check(err, IsNil) - dbt.Check(a, Equals, 1) - dbt.Check(b, Equals, 2) - dbt.Check(c, Equals, 300) - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, 1, a) + require.Equal(t, 2, b) + require.Equal(t, 300, c) + require.Truef(t, rows.Next(), "unexpected data") err = rows.Scan(&a, &b, &c) - dbt.Check(err, IsNil) - dbt.Check(a, Equals, 4) - dbt.Check(b, Equals, 5) - dbt.Check(c, Equals, 600) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + require.NoError(t, err) + require.Equal(t, 4, a) + require.Equal(t, 5, b) + require.Equal(t, 600, c) + require.Falsef(t, rows.Next(), "unexpected data") - dbt.mustExec("drop table if exists pn") + dbt.MustExec("drop table if exists pn") }) } @@ -1528,67 +1529,67 @@ func (cli *testServerClient) runTestExplainForConn(c *C) { }) } -func (cli *testServerClient) runTestErrorCode(c *C) { - cli.runTestsOnNewDB(c, nil, "ErrorCode", func(dbt *DBTest) { - dbt.mustExec("create table test (c int PRIMARY KEY);") - dbt.mustExec("insert into test values (1);") - txn1, err := dbt.db.Begin() - c.Assert(err, IsNil) +func (cli *testingServerClient) runTestErrorCode(t *testing.T) { + cli.runTestsOnNewDB(t, nil, "ErrorCode", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table test (c int PRIMARY KEY);") + dbt.MustExec("insert into test values (1);") + txn1, err := dbt.GetDB().Begin() + require.NoError(t, err) _, err = txn1.Exec("insert into test values(1)") - c.Assert(err, IsNil) + require.NoError(t, err) err = txn1.Commit() - checkErrorCode(c, err, errno.ErrDupEntry) + checkErrorCode(t, err, errno.ErrDupEntry) // Schema errors - txn2, err := dbt.db.Begin() - c.Assert(err, IsNil) + txn2, err := dbt.GetDB().Begin() + require.NoError(t, err) _, err = txn2.Exec("use db_not_exists;") - checkErrorCode(c, err, errno.ErrBadDB) + checkErrorCode(t, err, errno.ErrBadDB) _, err = txn2.Exec("select * from tbl_not_exists;") - checkErrorCode(c, err, errno.ErrNoSuchTable) + checkErrorCode(t, err, errno.ErrNoSuchTable) _, err = txn2.Exec("create database test;") // Make tests stable. Some times the error may be the ErrInfoSchemaChanged. - checkErrorCode(c, err, errno.ErrDBCreateExists, errno.ErrInfoSchemaChanged) + checkErrorCode(t, err, errno.ErrDBCreateExists, errno.ErrInfoSchemaChanged) _, err = txn2.Exec("create database aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa;") - checkErrorCode(c, err, errno.ErrTooLongIdent, errno.ErrInfoSchemaChanged) + checkErrorCode(t, err, errno.ErrTooLongIdent, errno.ErrInfoSchemaChanged) _, err = txn2.Exec("create table test (c int);") - checkErrorCode(c, err, errno.ErrTableExists, errno.ErrInfoSchemaChanged) + checkErrorCode(t, err, errno.ErrTableExists, errno.ErrInfoSchemaChanged) _, err = txn2.Exec("drop table unknown_table;") - checkErrorCode(c, err, errno.ErrBadTable, errno.ErrInfoSchemaChanged) + checkErrorCode(t, err, errno.ErrBadTable, errno.ErrInfoSchemaChanged) _, err = txn2.Exec("drop database unknown_db;") - checkErrorCode(c, err, errno.ErrDBDropExists, errno.ErrInfoSchemaChanged) + checkErrorCode(t, err, errno.ErrDBDropExists, errno.ErrInfoSchemaChanged) _, err = txn2.Exec("create table aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa (a int);") - checkErrorCode(c, err, errno.ErrTooLongIdent, errno.ErrInfoSchemaChanged) + checkErrorCode(t, err, errno.ErrTooLongIdent, errno.ErrInfoSchemaChanged) _, err = txn2.Exec("create table long_column_table (aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa int);") - checkErrorCode(c, err, errno.ErrTooLongIdent, errno.ErrInfoSchemaChanged) + checkErrorCode(t, err, errno.ErrTooLongIdent, errno.ErrInfoSchemaChanged) _, err = txn2.Exec("alter table test add aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa int;") - checkErrorCode(c, err, errno.ErrTooLongIdent, errno.ErrInfoSchemaChanged) + checkErrorCode(t, err, errno.ErrTooLongIdent, errno.ErrInfoSchemaChanged) // Optimizer errors _, err = txn2.Exec("select *, * from test;") - checkErrorCode(c, err, errno.ErrInvalidWildCard) + checkErrorCode(t, err, errno.ErrInvalidWildCard) _, err = txn2.Exec("select row(1, 2) > 1;") - checkErrorCode(c, err, errno.ErrOperandColumns) + checkErrorCode(t, err, errno.ErrOperandColumns) _, err = txn2.Exec("select * from test order by row(c, c);") - checkErrorCode(c, err, errno.ErrOperandColumns) + checkErrorCode(t, err, errno.ErrOperandColumns) // Variable errors _, err = txn2.Exec("select @@unknown_sys_var;") - checkErrorCode(c, err, errno.ErrUnknownSystemVariable) + checkErrorCode(t, err, errno.ErrUnknownSystemVariable) _, err = txn2.Exec("set @@unknown_sys_var='1';") - checkErrorCode(c, err, errno.ErrUnknownSystemVariable) + checkErrorCode(t, err, errno.ErrUnknownSystemVariable) // Expression errors _, err = txn2.Exec("select greatest(2);") - checkErrorCode(c, err, errno.ErrWrongParamcountToNativeFct) + checkErrorCode(t, err, errno.ErrWrongParamcountToNativeFct) }) } -func checkErrorCode(c *C, e error, codes ...uint16) { +func checkErrorCode(t *testing.T, e error, codes ...uint16) { me, ok := e.(*mysql.MySQLError) - c.Assert(ok, IsTrue, Commentf("err: %v", e)) + require.Truef(t, ok, "err: %v", e) if len(codes) == 1 { - c.Assert(me.Number, Equals, codes[0]) + require.Equal(t, codes[0], me.Number) } isMatchCode := false for _, code := range codes { @@ -1597,7 +1598,7 @@ func checkErrorCode(c *C, e error, codes ...uint16) { break } } - c.Assert(isMatchCode, IsTrue, Commentf("got err %v, expected err codes %v", me, codes)) + require.Truef(t, isMatchCode, "got err %v, expected err codes %v", me, codes) } func (cli *testingServerClient) runTestAuth(t *testing.T) { diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go new file mode 100644 index 0000000000000..0830caff1bd07 --- /dev/null +++ b/server/tidb_serial_test.go @@ -0,0 +1,28 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import "testing" + +// this test will change `kv.TxnTotalSizeLimit` which may affect other test suites, +// so we must make it running in serial. +func TestLoadData(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestLoadData(t, ts.server) + ts.runTestLoadDataWithSelectIntoOutfile(t, ts.server) + ts.runTestLoadDataForSlowLog(t, ts.server) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index 1aa093ffa229e..1b93dd228f550 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -275,14 +275,6 @@ func (ts *tidbTestSerialSuite) TestConfigDefaultValue(c *C) { }) } -// this test will change `kv.TxnTotalSizeLimit` which may affect other test suites, -// so we must make it running in serial. -func (ts *tidbTestSerialSuite) TestLoadData(c *C) { - ts.runTestLoadData(c, ts.server) - ts.runTestLoadDataWithSelectIntoOutfile(c, ts.server) - ts.runTestLoadDataForSlowLog(c, ts.server) -} - func (ts *tidbTestSerialSuite) TestLoadDataListPartition(c *C) { ts.runTestLoadDataForListPartition(c) ts.runTestLoadDataForListPartition2(c) @@ -316,9 +308,12 @@ func TestConcurrentUpdate(t *testing.T) { ts.runTestConcurrentUpdate(t) } -func (ts *tidbTestSuite) TestErrorCode(c *C) { - c.Parallel() - ts.runTestErrorCode(c) +func TestErrorCode(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestErrorCode(t) } func TestAuth(t *testing.T) { diff --git a/testkit/dbtestkit.go b/testkit/dbtestkit.go index f7d6de4d07244..f031c9b92bcc3 100644 --- a/testkit/dbtestkit.go +++ b/testkit/dbtestkit.go @@ -42,6 +42,24 @@ func NewDBTestKit(t *testing.T, db *sql.DB) *DBTestKit { } } +func (tk *DBTestKit) MustPrepare(query string) *sql.Stmt { + stmt, err := tk.db.Prepare(query) + tk.require.NoErrorf(err, "Prepare %s", query) + return stmt +} + +func (tk *DBTestKit) MustExecPrepared(stmt *sql.Stmt, args ...interface{}) sql.Result { + res, err := stmt.Exec(args...) + tk.require.NoErrorf(err, "Execute prepared with args: %s", args) + return res +} + +func (tk *DBTestKit) MustQueryPrepared(stmt *sql.Stmt, args ...interface{}) *sql.Rows { + rows, err := stmt.Query(args...) + tk.require.NoErrorf(err, "Query prepared with args: %s", args) + return rows +} + // MustExec query the statements and returns the result. func (tk *DBTestKit) MustExec(sql string, args ...interface{}) sql.Result { comment := fmt.Sprintf("sql:%s, args:%v", sql, args) @@ -60,6 +78,12 @@ func (tk *DBTestKit) MustQuery(sql string, args ...interface{}) *sql.Rows { return rows } +func (tk *DBTestKit) MustQueryRows(query string, args ...interface{}) { + rows := tk.MustQuery(query, args...) + tk.require.True(rows.Next()) + rows.Close() +} + // GetDB returns the underlay sql.DB instance. func (tk *DBTestKit) GetDB() *sql.DB { return tk.db From 0476bd7cd74dc29b28a81d224bb9b569ec479af7 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 17:29:56 +0800 Subject: [PATCH 18/55] server: migrate `TestConfigDefaultValue` --- server/server_test.go | 31 +++++++++++++++++++++++++++++++ server/tidb_serial_test.go | 16 +++++++++++++++- server/tidb_test.go | 7 ------- 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 6e23ad469ea6f..c06166946932c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -920,6 +920,37 @@ func (cli *testServerClient) checkRows(c *C, rows *sql.Rows, expectedRows ...str c.Assert(strings.Join(result, "\n"), Equals, strings.Join(expectedRows, "\n")) } +func (cli *testingServerClient) checkRows(t *testing.T, rows *sql.Rows, expectedRows ...string) { + buf := bytes.NewBuffer(nil) + result := make([]string, 0, 2) + for rows.Next() { + cols, err := rows.Columns() + require.NoError(t, err) + rawResult := make([][]byte, len(cols)) + dest := make([]interface{}, len(cols)) + for i := range rawResult { + dest[i] = &rawResult[i] + } + + err = rows.Scan(dest...) + require.NoError(t, err) + buf.Reset() + for i, raw := range rawResult { + if i > 0 { + buf.WriteString(" ") + } + if raw == nil { + buf.WriteString("") + } else { + buf.WriteString(string(raw)) + } + } + result = append(result, buf.String()) + } + + require.Equal(t, strings.Join(expectedRows, "\n"), strings.Join(result, "\n")) +} + func (cli *testingServerClient) runTestLoadData(t *testing.T, server *Server) { // create a file and write data. path := "/tmp/load_data_test.csv" diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 0830caff1bd07..a7c0c650c711f 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -14,7 +14,11 @@ package server -import "testing" +import ( + "testing" + + "github.com/pingcap/tidb/testkit" +) // this test will change `kv.TxnTotalSizeLimit` which may affect other test suites, // so we must make it running in serial. @@ -26,3 +30,13 @@ func TestLoadData(t *testing.T) { ts.runTestLoadDataWithSelectIntoOutfile(t, ts.server) ts.runTestLoadDataForSlowLog(t, ts.server) } + +func TestConfigDefaultValue(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestsOnNewDB(t, nil, "config", func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select @@tidb_slow_log_threshold;") + ts.checkRows(t, rows, "300") + }) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index 1b93dd228f550..9c19fb52eee5a 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -268,13 +268,6 @@ func TestPreparedTimestamp(t *testing.T) { ts.runTestPreparedTimestamp(t) } -func (ts *tidbTestSerialSuite) TestConfigDefaultValue(c *C) { - ts.runTestsOnNewDB(c, nil, "config", func(dbt *DBTest) { - rows := dbt.mustQuery("select @@tidb_slow_log_threshold;") - ts.checkRows(c, rows, "300") - }) -} - func (ts *tidbTestSerialSuite) TestLoadDataListPartition(c *C) { ts.runTestLoadDataForListPartition(c) ts.runTestLoadDataForListPartition2(c) From 45b28de95ad4f898d9aeb881017bd6a4ef131457 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 17:48:36 +0800 Subject: [PATCH 19/55] server: migrate `TestLoadDataAutoRandom` --- server/server_test.go | 32 ++++++++++++++++---------------- server/tidb_serial_test.go | 9 +++++++++ server/tidb_test.go | 6 ------ 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index c06166946932c..7bbb8f1aaffd1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -577,12 +577,12 @@ func (cli *testServerClient) prepareLoadDataFile(c *C, path string, rows ...stri c.Assert(err, IsNil) } -func (cli *testServerClient) runTestLoadDataAutoRandom(c *C) { +func (cli *testingServerClient) runTestLoadDataAutoRandom(t *testing.T) { path := "/tmp/load_data_txn_error.csv" fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) defer func() { _ = os.Remove(path) @@ -597,9 +597,9 @@ func (cli *testServerClient) runTestLoadDataAutoRandom(c *C) { str2 := strconv.Itoa(n2) row := str1 + "\t" + str2 _, err := fp.WriteString(row) - c.Assert(err, IsNil) + require.NoError(t, err) _, err = fp.WriteString("\n") - c.Assert(err, IsNil) + require.NoError(t, err) if i == 0 { cksum1 = n1 @@ -611,24 +611,24 @@ func (cli *testServerClient) runTestLoadDataAutoRandom(c *C) { } err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "load_data_batch_dml", func(dbt *DBTest) { + }, "load_data_batch_dml", func(dbt *testkit.DBTestKit) { // Set batch size, and check if load data got a invalid txn error. - dbt.mustExec("set @@session.tidb_dml_batch_size = 128") - dbt.mustExec("drop table if exists t") - dbt.mustExec("create table t(c1 bigint auto_random primary key, c2 bigint, c3 bigint)") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t (c2, c3)", path)) - rows := dbt.mustQuery("select count(*) from t") - cli.checkRows(c, rows, "50000") - rows = dbt.mustQuery("select bit_xor(c2), bit_xor(c3) from t") + dbt.MustExec("set @@session.tidb_dml_batch_size = 128") + dbt.MustExec("drop table if exists t") + dbt.MustExec("create table t(c1 bigint auto_random primary key, c2 bigint, c3 bigint)") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t (c2, c3)", path)) + rows := dbt.MustQuery("select count(*) from t") + cli.checkRows(t, rows, "50000") + rows = dbt.MustQuery("select bit_xor(c2), bit_xor(c3) from t") res := strconv.Itoa(cksum1) res = res + " " res = res + strconv.Itoa(cksum2) - cli.checkRows(c, rows, res) + cli.checkRows(t, rows, res) }) } diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index a7c0c650c711f..38ab4be4b1a51 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -40,3 +40,12 @@ func TestConfigDefaultValue(t *testing.T) { ts.checkRows(t, rows, "300") }) } + +// Fix issue#22540. Change tidb_dml_batch_size, +// then check if load data into table with auto random column works properly. +func TestLoadDataAutoRandom(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestLoadDataAutoRandom(t) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index 9c19fb52eee5a..76b1081ecd5a1 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -275,12 +275,6 @@ func (ts *tidbTestSerialSuite) TestLoadDataListPartition(c *C) { ts.runTestLoadDataForListColumnPartition2(c) } -// Fix issue#22540. Change tidb_dml_batch_size, -// then check if load data into table with auto random column works properly. -func (ts *tidbTestSerialSuite) TestLoadDataAutoRandom(c *C) { - ts.runTestLoadDataAutoRandom(c) -} - func (ts *tidbTestSerialSuite) TestLoadDataAutoRandomWithSpecialTerm(c *C) { ts.runTestLoadDataAutoRandomWithSpecialTerm(c) } From 4a00c682e29dc9d6857223b4e1939285e8737d49 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 17:58:34 +0800 Subject: [PATCH 20/55] server: migrate `TestExplainFor` and `TestLoadDataAutoRandomWithSpecialTerm` --- server/server_test.go | 62 +++++++++++++++++++------------------- server/tidb_serial_test.go | 14 +++++++++ server/tidb_test.go | 8 ----- 3 files changed, 45 insertions(+), 39 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 7bbb8f1aaffd1..b0ac7a67dd0b2 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -632,12 +632,12 @@ func (cli *testingServerClient) runTestLoadDataAutoRandom(t *testing.T) { }) } -func (cli *testServerClient) runTestLoadDataAutoRandomWithSpecialTerm(c *C) { +func (cli *testingServerClient) runTestLoadDataAutoRandomWithSpecialTerm(t *testing.T) { path := "/tmp/load_data_txn_error_term.csv" fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) defer func() { _ = os.Remove(path) @@ -652,11 +652,11 @@ func (cli *testServerClient) runTestLoadDataAutoRandomWithSpecialTerm(c *C) { str2 := strconv.Itoa(n2) row := "'" + str1 + "','" + str2 + "'" _, err := fp.WriteString(row) - c.Assert(err, IsNil) + require.NoError(t, err) if i != 49999 { _, err = fp.WriteString("|") } - c.Assert(err, IsNil) + require.NoError(t, err) if i == 0 { cksum1 = n1 @@ -668,24 +668,24 @@ func (cli *testServerClient) runTestLoadDataAutoRandomWithSpecialTerm(c *C) { } err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params = map[string]string{"sql_mode": "''"} - }, "load_data_batch_dml", func(dbt *DBTest) { + }, "load_data_batch_dml", func(dbt *testkit.DBTestKit) { // Set batch size, and check if load data got a invalid txn error. - dbt.mustExec("set @@session.tidb_dml_batch_size = 128") - dbt.mustExec("drop table if exists t1") - dbt.mustExec("create table t1(c1 bigint auto_random primary key, c2 bigint, c3 bigint)") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t1 fields terminated by ',' enclosed by '\\'' lines terminated by '|' (c2, c3)", path)) - rows := dbt.mustQuery("select count(*) from t1") - cli.checkRows(c, rows, "50000") - rows = dbt.mustQuery("select bit_xor(c2), bit_xor(c3) from t1") + dbt.MustExec("set @@session.tidb_dml_batch_size = 128") + dbt.MustExec("drop table if exists t1") + dbt.MustExec("create table t1(c1 bigint auto_random primary key, c2 bigint, c3 bigint)") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t1 fields terminated by ',' enclosed by '\\'' lines terminated by '|' (c2, c3)", path)) + rows := dbt.MustQuery("select count(*) from t1") + cli.checkRows(t, rows, "50000") + rows = dbt.MustQuery("select bit_xor(c2), bit_xor(c3) from t1") res := strconv.Itoa(cksum1) res = res + " " res = res + strconv.Itoa(cksum2) - cli.checkRows(c, rows, res) + cli.checkRows(t, rows, res) }) } @@ -1538,25 +1538,25 @@ func (cli *testingServerClient) runTestConcurrentUpdate(t *testing.T) { }) } -func (cli *testServerClient) runTestExplainForConn(c *C) { - cli.runTestsOnNewDB(c, nil, "explain_for_conn", func(dbt *DBTest) { - dbt.mustExec("drop table if exists t") - dbt.mustExec("create table t (a int key, b int)") - dbt.mustExec("insert t values (1, 1)") - rows := dbt.mustQuery("select connection_id();") - c.Assert(rows.Next(), IsTrue) +func (cli *testingServerClient) runTestExplainForConn(t *testing.T) { + cli.runTestsOnNewDB(t, nil, "explain_for_conn", func(dbt *testkit.DBTestKit) { + dbt.MustExec("drop table if exists t") + dbt.MustExec("create table t (a int key, b int)") + dbt.MustExec("insert t values (1, 1)") + rows := dbt.MustQuery("select connection_id();") + require.True(t, rows.Next()) var connID int64 err := rows.Scan(&connID) - c.Assert(err, IsNil) - c.Assert(rows.Close(), IsNil) - dbt.mustQuery("select * from t where a=1") - rows = dbt.mustQuery("explain for connection " + strconv.Itoa(int(connID))) - c.Assert(rows.Next(), IsTrue) + require.NoError(t, err) + require.NoError(t, rows.Close()) + dbt.MustQuery("select * from t where a=1") + rows = dbt.MustQuery("explain for connection " + strconv.Itoa(int(connID))) + require.True(t, rows.Next()) row := make([]string, 9) err = rows.Scan(&row[0], &row[1], &row[2], &row[3], &row[4], &row[5], &row[6], &row[7], &row[8]) - c.Assert(err, IsNil) - c.Assert(strings.Join(row, ","), Matches, "Point_Get_1,1.00,1,root,table:t,time.*loop.*handle:1.*") - c.Assert(rows.Close(), IsNil) + require.NoError(t, err) + require.Regexp(t, "Point_Get_1,1.00,1,root,table:t,time.*loop.*handle:1.*", strings.Join(row, ",")) + require.NoError(t, rows.Close()) }) } diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 38ab4be4b1a51..6d4f1875946d9 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -49,3 +49,17 @@ func TestLoadDataAutoRandom(t *testing.T) { ts.runTestLoadDataAutoRandom(t) } + +func TestLoadDataAutoRandomWithSpecialTerm(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestLoadDataAutoRandomWithSpecialTerm(t) +} + +func TestExplainFor(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestExplainForConn(t) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index 76b1081ecd5a1..a070235e4417e 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -275,14 +275,6 @@ func (ts *tidbTestSerialSuite) TestLoadDataListPartition(c *C) { ts.runTestLoadDataForListColumnPartition2(c) } -func (ts *tidbTestSerialSuite) TestLoadDataAutoRandomWithSpecialTerm(c *C) { - ts.runTestLoadDataAutoRandomWithSpecialTerm(c) -} - -func (ts *tidbTestSerialSuite) TestExplainFor(c *C) { - ts.runTestExplainForConn(c) -} - func (ts *tidbTestSerialSuite) TestStmtCount(c *C) { ts.runTestStmtCount(c) } From 874413d7d3d6c12641d51d77ac3727327b657831 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 18:08:39 +0800 Subject: [PATCH 21/55] server: migrate `TestStmtCount` --- server/server_test.go | 58 +++++++++++++++++++------------------- server/tidb_serial_test.go | 7 +++++ server/tidb_test.go | 4 --- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index b0ac7a67dd0b2..1c6891ed1bd2d 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1934,38 +1934,38 @@ func (cli *testServerClient) runTestMultiStatements(c *C) { }) } -func (cli *testServerClient) runTestStmtCount(t *C) { - cli.runTestsOnNewDB(t, nil, "StatementCount", func(dbt *DBTest) { +func (cli *testingServerClient) runTestStmtCount(t *testing.T) { + cli.runTestsOnNewDB(t, nil, "StatementCount", func(dbt *testkit.DBTestKit) { originStmtCnt := getStmtCnt(string(cli.getMetrics(t))) - dbt.mustExec("create table test (a int)") + dbt.MustExec("create table test (a int)") - dbt.mustExec("insert into test values(1)") - dbt.mustExec("insert into test values(2)") - dbt.mustExec("insert into test values(3)") - dbt.mustExec("insert into test values(4)") - dbt.mustExec("insert into test values(5)") + dbt.MustExec("insert into test values(1)") + dbt.MustExec("insert into test values(2)") + dbt.MustExec("insert into test values(3)") + dbt.MustExec("insert into test values(4)") + dbt.MustExec("insert into test values(5)") - dbt.mustExec("delete from test where a = 3") - dbt.mustExec("update test set a = 2 where a = 1") - dbt.mustExec("select * from test") - dbt.mustExec("select 2") + dbt.MustExec("delete from test where a = 3") + dbt.MustExec("update test set a = 2 where a = 1") + dbt.MustExec("select * from test") + dbt.MustExec("select 2") - dbt.mustExec("prepare stmt1 from 'update test set a = 1 where a = 2'") - dbt.mustExec("execute stmt1") - dbt.mustExec("prepare stmt2 from 'select * from test'") - dbt.mustExec("execute stmt2") - dbt.mustExec("replace into test(a) values(6);") + dbt.MustExec("prepare stmt1 from 'update test set a = 1 where a = 2'") + dbt.MustExec("execute stmt1") + dbt.MustExec("prepare stmt2 from 'select * from test'") + dbt.MustExec("execute stmt2") + dbt.MustExec("replace into test(a) values(6);") currentStmtCnt := getStmtCnt(string(cli.getMetrics(t))) - t.Assert(currentStmtCnt["CreateTable"], Equals, originStmtCnt["CreateTable"]+1) - t.Assert(currentStmtCnt["Insert"], Equals, originStmtCnt["Insert"]+5) - t.Assert(currentStmtCnt["Delete"], Equals, originStmtCnt["Delete"]+1) - t.Assert(currentStmtCnt["Update"], Equals, originStmtCnt["Update"]+1) - t.Assert(currentStmtCnt["Select"], Equals, originStmtCnt["Select"]+2) - t.Assert(currentStmtCnt["Prepare"], Equals, originStmtCnt["Prepare"]+2) - t.Assert(currentStmtCnt["Execute"], Equals, originStmtCnt["Execute"]+2) - t.Assert(currentStmtCnt["Replace"], Equals, originStmtCnt["Replace"]+1) + require.Equal(t, originStmtCnt["CreateTable"]+1, currentStmtCnt["CreateTable"]) + require.Equal(t, originStmtCnt["Insert"]+5, currentStmtCnt["Insert"]) + require.Equal(t, originStmtCnt["Delete"]+1, currentStmtCnt["Delete"]) + require.Equal(t, originStmtCnt["Update"]+1, currentStmtCnt["Update"]) + require.Equal(t, originStmtCnt["Select"]+2, currentStmtCnt["Select"]) + require.Equal(t, originStmtCnt["Prepare"]+2, currentStmtCnt["Prepare"]) + require.Equal(t, originStmtCnt["Execute"]+2, currentStmtCnt["Execute"]) + require.Equal(t, originStmtCnt["Replace"]+1, currentStmtCnt["Replace"]) }) } @@ -2021,13 +2021,13 @@ func (cli *testServerClient) runTestSumAvg(c *C) { }) } -func (cli *testServerClient) getMetrics(t *C) []byte { +func (cli *testingServerClient) getMetrics(t *testing.T) []byte { resp, err := cli.fetchStatus("/metrics") - t.Assert(err, IsNil) + require.NoError(t, err) content, err := io.ReadAll(resp.Body) - t.Assert(err, IsNil) + require.NoError(t, err) err = resp.Body.Close() - t.Assert(err, IsNil) + require.NoError(t, err) return content } diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 6d4f1875946d9..aa82e6f1b70d9 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -63,3 +63,10 @@ func TestExplainFor(t *testing.T) { ts.runTestExplainForConn(t) } + +func TestStmtCount(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestStmtCount(t) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index a070235e4417e..1d6c3279dd52d 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -275,10 +275,6 @@ func (ts *tidbTestSerialSuite) TestLoadDataListPartition(c *C) { ts.runTestLoadDataForListColumnPartition2(c) } -func (ts *tidbTestSerialSuite) TestStmtCount(c *C) { - ts.runTestStmtCount(c) -} - func TestConcurrentUpdate(t *testing.T) { t.Parallel() ts, cleanup := createTiDBTest(t) From 5fbfadf38187c72871ed1dbfc77c315c15c08677 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 18:26:00 +0800 Subject: [PATCH 22/55] server: migrate `TestLoadDataListPartition` --- server/server_test.go | 238 ++++++++++++++++++------------------- server/tidb_serial_test.go | 10 ++ server/tidb_test.go | 7 -- 3 files changed, 129 insertions(+), 126 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 1c6891ed1bd2d..2e9592bee8d80 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -561,20 +561,20 @@ func (cli *testingServerClient) runTestLoadDataForSlowLog(t *testing.T, server * }) } -func (cli *testServerClient) prepareLoadDataFile(c *C, path string, rows ...string) { +func (cli *testingServerClient) prepareLoadDataFile(t *testing.T, path string, rows ...string) { fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) + require.NoError(t, err) + require.NotNil(t, fp) defer func() { err = fp.Close() - c.Assert(err, IsNil) + require.NoError(t, err) }() for _, row := range rows { fields := strings.Split(row, " ") _, err = fp.WriteString(strings.Join(fields, "\t")) _, err = fp.WriteString("\n") } - c.Assert(err, IsNil) + require.NoError(t, err) } func (cli *testingServerClient) runTestLoadDataAutoRandom(t *testing.T) { @@ -689,18 +689,18 @@ func (cli *testingServerClient) runTestLoadDataAutoRandomWithSpecialTerm(t *test }) } -func (cli *testServerClient) runTestLoadDataForListPartition(c *C) { +func (cli *testingServerClient) runTestLoadDataForListPartition(t *testing.T) { path := "/tmp/load_data_list_partition.csv" defer func() { _ = os.Remove(path) }() - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "load_data_list_partition", func(dbt *DBTest) { - dbt.mustExec("set @@session.tidb_enable_list_partition = ON") - dbt.mustExec(`create table t (id int, name varchar(10), + }, "load_data_list_partition", func(dbt *testkit.DBTestKit) { + dbt.MustExec("set @@session.tidb_enable_list_partition = ON") + dbt.MustExec(`create table t (id int, name varchar(10), unique index idx (id)) partition by list (id) ( partition p0 values in (3,5,6,9,17), partition p1 values in (1,2,10,11,19,20), @@ -708,48 +708,48 @@ func (cli *testServerClient) runTestLoadDataForListPartition(c *C) { partition p3 values in (7,8,15,16,null) );`) // Test load data into 1 partition. - cli.prepareLoadDataFile(c, path, "1 a", "2 b") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t", path)) - rows := dbt.mustQuery("select * from t partition(p1) order by id") - cli.checkRows(c, rows, "1 a", "2 b") + cli.prepareLoadDataFile(t, path, "1 a", "2 b") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path)) + rows := dbt.MustQuery("select * from t partition(p1) order by id") + cli.checkRows(t, rows, "1 a", "2 b") // Test load data into multi-partitions. - dbt.mustExec("delete from t") - cli.prepareLoadDataFile(c, path, "1 a", "3 c", "4 e") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t", path)) - rows = dbt.mustQuery("select * from t order by id") - cli.checkRows(c, rows, "1 a", "3 c", "4 e") + dbt.MustExec("delete from t") + cli.prepareLoadDataFile(t, path, "1 a", "3 c", "4 e") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path)) + rows = dbt.MustQuery("select * from t order by id") + cli.checkRows(t, rows, "1 a", "3 c", "4 e") // Test load data meet duplicate error. - cli.prepareLoadDataFile(c, path, "1 x", "2 b", "2 x", "7 a") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t", path)) - rows = dbt.mustQuery("show warnings") - cli.checkRows(c, rows, + cli.prepareLoadDataFile(t, path, "1 x", "2 b", "2 x", "7 a") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path)) + rows = dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 1062 Duplicate entry '1' for key 'idx'", "Warning 1062 Duplicate entry '2' for key 'idx'") - rows = dbt.mustQuery("select * from t order by id") - cli.checkRows(c, rows, "1 a", "2 b", "3 c", "4 e", "7 a") + rows = dbt.MustQuery("select * from t order by id") + cli.checkRows(t, rows, "1 a", "2 b", "3 c", "4 e", "7 a") // Test load data meet no partition warning. - cli.prepareLoadDataFile(c, path, "5 a", "100 x") - _, err := dbt.db.Exec(fmt.Sprintf("load data local infile %q into table t", path)) - c.Assert(err, IsNil) - rows = dbt.mustQuery("show warnings") - cli.checkRows(c, rows, "Warning 1526 Table has no partition for value 100") - rows = dbt.mustQuery("select * from t order by id") - cli.checkRows(c, rows, "1 a", "2 b", "3 c", "4 e", "5 a", "7 a") + cli.prepareLoadDataFile(t, path, "5 a", "100 x") + _, err := dbt.GetDB().Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + rows = dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 1526 Table has no partition for value 100") + rows = dbt.MustQuery("select * from t order by id") + cli.checkRows(t, rows, "1 a", "2 b", "3 c", "4 e", "5 a", "7 a") }) } -func (cli *testServerClient) runTestLoadDataForListPartition2(c *C) { +func (cli *testingServerClient) runTestLoadDataForListPartition2(t *testing.T) { path := "/tmp/load_data_list_partition.csv" defer func() { _ = os.Remove(path) }() - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "load_data_list_partition", func(dbt *DBTest) { - dbt.mustExec("set @@session.tidb_enable_list_partition = ON") - dbt.mustExec(`create table t (id int, name varchar(10),b int generated always as (length(name)+1) virtual, + }, "load_data_list_partition", func(dbt *testkit.DBTestKit) { + dbt.MustExec("set @@session.tidb_enable_list_partition = ON") + dbt.MustExec(`create table t (id int, name varchar(10),b int generated always as (length(name)+1) virtual, unique index idx (id,b)) partition by list (id*2 + b*b + b*b - b*b*2 - abs(id)) ( partition p0 values in (3,5,6,9,17), partition p1 values in (1,2,10,11,19,20), @@ -757,48 +757,48 @@ func (cli *testServerClient) runTestLoadDataForListPartition2(c *C) { partition p3 values in (7,8,15,16,null) );`) // Test load data into 1 partition. - cli.prepareLoadDataFile(c, path, "1 a", "2 b") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t (id,name)", path)) - rows := dbt.mustQuery("select id,name from t partition(p1) order by id") - cli.checkRows(c, rows, "1 a", "2 b") + cli.prepareLoadDataFile(t, path, "1 a", "2 b") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t (id,name)", path)) + rows := dbt.MustQuery("select id,name from t partition(p1) order by id") + cli.checkRows(t, rows, "1 a", "2 b") // Test load data into multi-partitions. - dbt.mustExec("delete from t") - cli.prepareLoadDataFile(c, path, "1 a", "3 c", "4 e") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t (id,name)", path)) - rows = dbt.mustQuery("select id,name from t order by id") - cli.checkRows(c, rows, "1 a", "3 c", "4 e") + dbt.MustExec("delete from t") + cli.prepareLoadDataFile(t, path, "1 a", "3 c", "4 e") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t (id,name)", path)) + rows = dbt.MustQuery("select id,name from t order by id") + cli.checkRows(t, rows, "1 a", "3 c", "4 e") // Test load data meet duplicate error. - cli.prepareLoadDataFile(c, path, "1 x", "2 b", "2 x", "7 a") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t (id,name)", path)) - rows = dbt.mustQuery("show warnings") - cli.checkRows(c, rows, + cli.prepareLoadDataFile(t, path, "1 x", "2 b", "2 x", "7 a") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t (id,name)", path)) + rows = dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 1062 Duplicate entry '1-2' for key 'idx'", "Warning 1062 Duplicate entry '2-2' for key 'idx'") - rows = dbt.mustQuery("select id,name from t order by id") - cli.checkRows(c, rows, "1 a", "2 b", "3 c", "4 e", "7 a") + rows = dbt.MustQuery("select id,name from t order by id") + cli.checkRows(t, rows, "1 a", "2 b", "3 c", "4 e", "7 a") // Test load data meet no partition warning. - cli.prepareLoadDataFile(c, path, "5 a", "100 x") - _, err := dbt.db.Exec(fmt.Sprintf("load data local infile %q into table t (id,name)", path)) - c.Assert(err, IsNil) - rows = dbt.mustQuery("show warnings") - cli.checkRows(c, rows, "Warning 1526 Table has no partition for value 100") - rows = dbt.mustQuery("select id,name from t order by id") - cli.checkRows(c, rows, "1 a", "2 b", "3 c", "4 e", "5 a", "7 a") + cli.prepareLoadDataFile(t, path, "5 a", "100 x") + _, err := dbt.GetDB().Exec(fmt.Sprintf("load data local infile %q into table t (id,name)", path)) + require.NoError(t, err) + rows = dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 1526 Table has no partition for value 100") + rows = dbt.MustQuery("select id,name from t order by id") + cli.checkRows(t, rows, "1 a", "2 b", "3 c", "4 e", "5 a", "7 a") }) } -func (cli *testServerClient) runTestLoadDataForListColumnPartition(c *C) { +func (cli *testingServerClient) runTestLoadDataForListColumnPartition(t *testing.T) { path := "/tmp/load_data_list_partition.csv" defer func() { _ = os.Remove(path) }() - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "load_data_list_partition", func(dbt *DBTest) { - dbt.mustExec("set @@session.tidb_enable_list_partition = ON") - dbt.mustExec(`create table t (id int, name varchar(10), + }, "load_data_list_partition", func(dbt *testkit.DBTestKit) { + dbt.MustExec("set @@session.tidb_enable_list_partition = ON") + dbt.MustExec(`create table t (id int, name varchar(10), unique index idx (id)) partition by list columns (id) ( partition p0 values in (3,5,6,9,17), partition p1 values in (1,2,10,11,19,20), @@ -806,87 +806,87 @@ func (cli *testServerClient) runTestLoadDataForListColumnPartition(c *C) { partition p3 values in (7,8,15,16,null) );`) // Test load data into 1 partition. - cli.prepareLoadDataFile(c, path, "1 a", "2 b") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t", path)) - rows := dbt.mustQuery("select * from t partition(p1) order by id") - cli.checkRows(c, rows, "1 a", "2 b") + cli.prepareLoadDataFile(t, path, "1 a", "2 b") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path)) + rows := dbt.MustQuery("select * from t partition(p1) order by id") + cli.checkRows(t, rows, "1 a", "2 b") // Test load data into multi-partitions. - dbt.mustExec("delete from t") - cli.prepareLoadDataFile(c, path, "1 a", "3 c", "4 e") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t", path)) - rows = dbt.mustQuery("select * from t order by id") - cli.checkRows(c, rows, "1 a", "3 c", "4 e") + dbt.MustExec("delete from t") + cli.prepareLoadDataFile(t, path, "1 a", "3 c", "4 e") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path)) + rows = dbt.MustQuery("select * from t order by id") + cli.checkRows(t, rows, "1 a", "3 c", "4 e") // Test load data meet duplicate error. - cli.prepareLoadDataFile(c, path, "1 x", "2 b", "2 x", "7 a") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t", path)) - rows = dbt.mustQuery("show warnings") - cli.checkRows(c, rows, + cli.prepareLoadDataFile(t, path, "1 x", "2 b", "2 x", "7 a") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path)) + rows = dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 1062 Duplicate entry '1' for key 'idx'", "Warning 1062 Duplicate entry '2' for key 'idx'") - rows = dbt.mustQuery("select * from t order by id") - cli.checkRows(c, rows, "1 a", "2 b", "3 c", "4 e", "7 a") + rows = dbt.MustQuery("select * from t order by id") + cli.checkRows(t, rows, "1 a", "2 b", "3 c", "4 e", "7 a") // Test load data meet no partition warning. - cli.prepareLoadDataFile(c, path, "5 a", "100 x") - _, err := dbt.db.Exec(fmt.Sprintf("load data local infile %q into table t", path)) - c.Assert(err, IsNil) - rows = dbt.mustQuery("show warnings") - cli.checkRows(c, rows, "Warning 1526 Table has no partition for value from column_list") - rows = dbt.mustQuery("select id,name from t order by id") - cli.checkRows(c, rows, "1 a", "2 b", "3 c", "4 e", "5 a", "7 a") + cli.prepareLoadDataFile(t, path, "5 a", "100 x") + _, err := dbt.GetDB().Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + rows = dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 1526 Table has no partition for value from column_list") + rows = dbt.MustQuery("select id,name from t order by id") + cli.checkRows(t, rows, "1 a", "2 b", "3 c", "4 e", "5 a", "7 a") }) } -func (cli *testServerClient) runTestLoadDataForListColumnPartition2(c *C) { +func (cli *testingServerClient) runTestLoadDataForListColumnPartition2(t *testing.T) { path := "/tmp/load_data_list_partition.csv" defer func() { _ = os.Remove(path) }() - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" - }, "load_data_list_partition", func(dbt *DBTest) { - dbt.mustExec("set @@session.tidb_enable_list_partition = ON") - dbt.mustExec(`create table t (location varchar(10), id int, a int, unique index idx (location,id)) partition by list columns (location,id) ( + }, "load_data_list_partition", func(dbt *testkit.DBTestKit) { + dbt.MustExec("set @@session.tidb_enable_list_partition = ON") + dbt.MustExec(`create table t (location varchar(10), id int, a int, unique index idx (location,id)) partition by list columns (location,id) ( partition p_west values in (('w', 1),('w', 2),('w', 3),('w', 4)), partition p_east values in (('e', 5),('e', 6),('e', 7),('e', 8)), partition p_north values in (('n', 9),('n',10),('n',11),('n',12)), partition p_south values in (('s',13),('s',14),('s',15),('s',16)) );`) // Test load data into 1 partition. - cli.prepareLoadDataFile(c, path, "w 1 1", "w 2 2") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t", path)) - rows := dbt.mustQuery("select * from t partition(p_west) order by id") - cli.checkRows(c, rows, "w 1 1", "w 2 2") + cli.prepareLoadDataFile(t, path, "w 1 1", "w 2 2") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path)) + rows := dbt.MustQuery("select * from t partition(p_west) order by id") + cli.checkRows(t, rows, "w 1 1", "w 2 2") // Test load data into multi-partitions. - dbt.mustExec("delete from t") - cli.prepareLoadDataFile(c, path, "w 1 1", "e 5 5", "n 9 9") - dbt.mustExec(fmt.Sprintf("load data local infile %q into table t", path)) - rows = dbt.mustQuery("select * from t order by id") - cli.checkRows(c, rows, "w 1 1", "e 5 5", "n 9 9") + dbt.MustExec("delete from t") + cli.prepareLoadDataFile(t, path, "w 1 1", "e 5 5", "n 9 9") + dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path)) + rows = dbt.MustQuery("select * from t order by id") + cli.checkRows(t, rows, "w 1 1", "e 5 5", "n 9 9") // Test load data meet duplicate error. - cli.prepareLoadDataFile(c, path, "w 1 2", "w 2 2") - _, err := dbt.db.Exec(fmt.Sprintf("load data local infile %q into table t", path)) - c.Assert(err, IsNil) - rows = dbt.mustQuery("show warnings") - cli.checkRows(c, rows, "Warning 1062 Duplicate entry 'w-1' for key 'idx'") - rows = dbt.mustQuery("select * from t order by id") - cli.checkRows(c, rows, "w 1 1", "w 2 2", "e 5 5", "n 9 9") + cli.prepareLoadDataFile(t, path, "w 1 2", "w 2 2") + _, err := dbt.GetDB().Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + rows = dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 1062 Duplicate entry 'w-1' for key 'idx'") + rows = dbt.MustQuery("select * from t order by id") + cli.checkRows(t, rows, "w 1 1", "w 2 2", "e 5 5", "n 9 9") // Test load data meet no partition warning. - cli.prepareLoadDataFile(c, path, "w 3 3", "w 5 5", "e 8 8") - _, err = dbt.db.Exec(fmt.Sprintf("load data local infile %q into table t", path)) - c.Assert(err, IsNil) - rows = dbt.mustQuery("show warnings") - cli.checkRows(c, rows, "Warning 1526 Table has no partition for value from column_list") - cli.prepareLoadDataFile(c, path, "x 1 1", "w 1 1") - _, err = dbt.db.Exec(fmt.Sprintf("load data local infile %q into table t", path)) - c.Assert(err, IsNil) - rows = dbt.mustQuery("show warnings") - cli.checkRows(c, rows, + cli.prepareLoadDataFile(t, path, "w 3 3", "w 5 5", "e 8 8") + _, err = dbt.GetDB().Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + rows = dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 1526 Table has no partition for value from column_list") + cli.prepareLoadDataFile(t, path, "x 1 1", "w 1 1") + _, err = dbt.GetDB().Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + rows = dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 1526 Table has no partition for value from column_list", "Warning 1062 Duplicate entry 'w-1' for key 'idx'") - rows = dbt.mustQuery("select * from t order by id") - cli.checkRows(c, rows, "w 1 1", "w 2 2", "w 3 3", "e 5 5", "e 8 8", "n 9 9") + rows = dbt.MustQuery("select * from t order by id") + cli.checkRows(t, rows, "w 1 1", "w 2 2", "w 3 3", "e 5 5", "e 8 8", "n 9 9") }) } diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index aa82e6f1b70d9..6dcb5a7b0a645 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -70,3 +70,13 @@ func TestStmtCount(t *testing.T) { ts.runTestStmtCount(t) } + +func TestLoadDataListPartition(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestLoadDataForListPartition(t) + ts.runTestLoadDataForListPartition2(t) + ts.runTestLoadDataForListColumnPartition(t) + ts.runTestLoadDataForListColumnPartition2(t) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index 1d6c3279dd52d..6afd937a3cbb6 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -268,13 +268,6 @@ func TestPreparedTimestamp(t *testing.T) { ts.runTestPreparedTimestamp(t) } -func (ts *tidbTestSerialSuite) TestLoadDataListPartition(c *C) { - ts.runTestLoadDataForListPartition(c) - ts.runTestLoadDataForListPartition2(c) - ts.runTestLoadDataForListColumnPartition(c) - ts.runTestLoadDataForListColumnPartition2(c) -} - func TestConcurrentUpdate(t *testing.T) { t.Parallel() ts, cleanup := createTiDBTest(t) From b04dcb56da752eceb485e9f585dbfea422fec702 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 18:41:33 +0800 Subject: [PATCH 23/55] server: migrate `TestRegression` --- server/server_test.go | 78 +++++++++++++++++++++++++++++++++++++++++++ server/tidb_test.go | 8 +++-- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 2e9592bee8d80..4f662d07df596 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -345,6 +345,84 @@ func (cli *testServerClient) runTestRegression(c *C, overrider configOverrider, }) } +func (cli *testingServerClient) runTestRegression(t *testing.T, overrider configOverrider, dbName string) { + cli.runTestsOnNewDB(t, overrider, dbName, func(dbt *testkit.DBTestKit) { + // Show the user + dbt.MustExec("select user()") + + // Create Table + dbt.MustExec("CREATE TABLE test (val TINYINT)") + + // Test for unexpected data + var out bool + rows := dbt.MustQuery("SELECT * FROM test") + require.Falsef(t, rows.Next(), "unexpected data in empty table") + + // Create Data + res := dbt.MustExec("INSERT INTO test VALUES (1)") + // res := dbt.mustExec("INSERT INTO test VALUES (?)", 1) + count, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), count) + id, err := res.LastInsertId() + require.NoError(t, err) + require.Equal(t, int64(0), id) + + // Read + rows = dbt.MustQuery("SELECT val FROM test") + if rows.Next() { + err = rows.Scan(&out) + require.NoError(t, err) + require.True(t, out) + require.Falsef(t, rows.Next(), "unexpected data") + } else { + require.Fail(t, "no data") + } + rows.Close() + + // Update + res = dbt.MustExec("UPDATE test SET val = 0 WHERE val = ?", 1) + count, err = res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), count) + + // Check Update + rows = dbt.MustQuery("SELECT val FROM test") + if rows.Next() { + err = rows.Scan(&out) + require.NoError(t, err) + require.False(t, out) + require.Falsef(t, rows.Next(), "unexpected data") + } else { + require.Fail(t, "no data") + } + rows.Close() + + // Delete + res = dbt.MustExec("DELETE FROM test WHERE val = 0") + // res = dbt.mustExec("DELETE FROM test WHERE val = ?", 0) + count, err = res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), count) + + // Check for unexpected rows + res = dbt.MustExec("DELETE FROM test") + count, err = res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(0), count) + + dbt.MustQueryRows("SELECT 1") + + var b = make([]byte, 0) + if err := dbt.GetDB().QueryRow("SELECT ?", b).Scan(&b); err != nil { + t.Fatal(err) + } + if b == nil { + require.Fail(t, "nil echo from non-nil input") + } + }) +} + func (cli *testingServerClient) runTestPrepareResultFieldType(t *testing.T) { var param int64 = 83 cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { diff --git a/server/tidb_test.go b/server/tidb_test.go index 6afd937a3cbb6..c361fe32d67fc 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -229,10 +229,12 @@ func (ts *tidbTestSuiteBase) TearDownSuite(c *C) { } } -func (ts *tidbTestSuite) TestRegression(c *C) { +func TestRegression(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() if regression { - c.Parallel() - ts.runTestRegression(c, nil, "Regression") + t.Parallel() + ts.runTestRegression(t, nil, "Regression") } } From f55e8194b7f087471c44a4488eab3d15526e3ed6 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 18:47:45 +0800 Subject: [PATCH 24/55] server: migrate `TestIssues` --- server/server_test.go | 26 +++++++++++++------------- server/tidb_test.go | 13 ++++++++----- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 4f662d07df596..043eacfa45537 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1764,47 +1764,47 @@ func (cli *testingServerClient) runTestAuth(t *testing.T) { }) } -func (cli *testServerClient) runTestIssue3662(c *C) { +func (cli *testingServerClient) runTestIssue3662(t *testing.T) { db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { config.DBName = "non_existing_schema" })) - c.Assert(err, IsNil) + require.NoError(t, err) defer func() { err := db.Close() - c.Assert(err, IsNil) + require.NoError(t, err) }() // According to documentation, "Open may just validate its arguments without // creating a connection to the database. To verify that the data source name // is valid, call Ping." err = db.Ping() - c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "Error 1049: Unknown database 'non_existing_schema'") + require.Error(t, err) + require.Equal(t, "Error 1049: Unknown database 'non_existing_schema'", err.Error()) } -func (cli *testServerClient) runTestIssue3680(c *C) { +func (cli *testingServerClient) runTestIssue3680(t *testing.T) { db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { config.User = "non_existing_user" })) - c.Assert(err, IsNil) + require.NoError(t, err) defer func() { err := db.Close() - c.Assert(err, IsNil) + require.NoError(t, err) }() // According to documentation, "Open may just validate its arguments without // creating a connection to the database. To verify that the data source name // is valid, call Ping." err = db.Ping() - c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "Error 1045: Access denied for user 'non_existing_user'@'127.0.0.1' (using password: NO)") + require.Error(t, err) + require.Equal(t, "Error 1045: Access denied for user 'non_existing_user'@'127.0.0.1' (using password: NO)", err.Error()) } -func (cli *testServerClient) runTestIssue22646(c *C) { - cli.runTests(c, nil, func(dbt *DBTest) { +func (cli *testingServerClient) runTestIssue22646(t *testing.T) { + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { c1 := make(chan string, 1) go func() { - dbt.mustExec(``) // empty query. + dbt.MustExec(``) // empty query. c1 <- "success" }() select { diff --git a/server/tidb_test.go b/server/tidb_test.go index c361fe32d67fc..a02de2aa70ff6 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -295,11 +295,14 @@ func TestAuth(t *testing.T) { ts.runTestIssue3682(t) } -func (ts *tidbTestSuite) TestIssues(c *C) { - c.Parallel() - ts.runTestIssue3662(c) - ts.runTestIssue3680(c) - ts.runTestIssue22646(c) +func TestIssues(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestIssue3662(t) + ts.runTestIssue3680(t) + ts.runTestIssue22646(t) } func (ts *tidbTestSuite) TestDBNameEscape(c *C) { From 923cf68be458302cc7fe530ddeb3f79cce990669 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 18:51:19 +0800 Subject: [PATCH 25/55] server: migrate `TestDBNameEscape` --- server/server_test.go | 14 +++++++------- server/tidb_test.go | 8 +++++--- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 043eacfa45537..cfd1e81de5018 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1843,15 +1843,15 @@ func (cli *testingServerClient) runTestIssue3682(t *testing.T) { require.Equal(t, "Error 1045: Access denied for user 'issue3682'@'127.0.0.1' (using password: YES)", err.Error()) } -func (cli *testServerClient) runTestDBNameEscape(c *C) { - cli.runTests(c, nil, func(dbt *DBTest) { - dbt.mustExec("CREATE DATABASE `aa-a`;") +func (cli *testingServerClient) runTestDBNameEscape(t *testing.T) { + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec("CREATE DATABASE `aa-a`;") }) - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.DBName = "aa-a" - }, func(dbt *DBTest) { - dbt.mustExec(`USE mysql;`) - dbt.mustExec("DROP DATABASE `aa-a`") + }, func(dbt *testkit.DBTestKit) { + dbt.MustExec(`USE mysql;`) + dbt.MustExec("DROP DATABASE `aa-a`") }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index a02de2aa70ff6..e9d6aa3afa633 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -305,9 +305,11 @@ func TestIssues(t *testing.T) { ts.runTestIssue22646(t) } -func (ts *tidbTestSuite) TestDBNameEscape(c *C) { - c.Parallel() - ts.runTestDBNameEscape(c) +func TestDBNameEscape(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + ts.runTestDBNameEscape(t) } func (ts *tidbTestSuite) TestResultFieldTableIsNull(c *C) { From 94ea04a1d18b08f655c03c5286780c6707ed7a03 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 18:54:42 +0800 Subject: [PATCH 26/55] server: migrate `TestResultFieldTableIsNull` --- server/server_test.go | 12 ++++++------ server/tidb_test.go | 9 ++++++--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index cfd1e81de5018..55291244058f9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1855,13 +1855,13 @@ func (cli *testingServerClient) runTestDBNameEscape(t *testing.T) { }) } -func (cli *testServerClient) runTestResultFieldTableIsNull(c *C) { - cli.runTestsOnNewDB(c, func(config *mysql.Config) { +func (cli *testingServerClient) runTestResultFieldTableIsNull(t *testing.T) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.Params["sql_mode"] = "''" - }, "ResultFieldTableIsNull", func(dbt *DBTest) { - dbt.mustExec("drop table if exists test;") - dbt.mustExec("create table test (c int);") - dbt.mustExec("explain select * from test;") + }, "ResultFieldTableIsNull", func(dbt *testkit.DBTestKit) { + dbt.MustExec("drop table if exists test;") + dbt.MustExec("create table test (c int);") + dbt.MustExec("explain select * from test;") }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index e9d6aa3afa633..d73cb749bf23b 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -312,9 +312,12 @@ func TestDBNameEscape(t *testing.T) { ts.runTestDBNameEscape(t) } -func (ts *tidbTestSuite) TestResultFieldTableIsNull(c *C) { - c.Parallel() - ts.runTestResultFieldTableIsNull(c) +func TestResultFieldTableIsNull(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestResultFieldTableIsNull(t) } func (ts *tidbTestSuite) TestStatusAPI(c *C) { From c8fc3df9a12c9ac6ac3f3d4f1e61b6c6042fd8d6 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 18:57:44 +0800 Subject: [PATCH 27/55] server: migrate `TestStatusAPI` --- server/tidb_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index d73cb749bf23b..51661fb01e5f0 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -320,9 +320,12 @@ func TestResultFieldTableIsNull(t *testing.T) { ts.runTestResultFieldTableIsNull(t) } -func (ts *tidbTestSuite) TestStatusAPI(c *C) { - c.Parallel() - ts.runTestStatusAPI(c) +func TestStatusAPI(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestStatusAPI(t) } func TestStatusPort(t *testing.T) { From 3e591efe4add044ba77e8e1bba2ac78e8183b11c Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 19:20:46 +0800 Subject: [PATCH 28/55] server: migrate `TestMultiStatements` --- server/server_test.go | 115 +++++++++++++++++++++--------------------- server/tidb_test.go | 11 ++-- 2 files changed, 64 insertions(+), 62 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 55291244058f9..8af4774d44ee4 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1893,122 +1893,121 @@ func (cli *testingServerClient) runTestStatusAPI(t *testing.T) { // disabled by default for security reasons. Lets ensure that the behavior // is correct. -func (cli *testServerClient) runFailedTestMultiStatements(c *C) { - cli.runTestsOnNewDB(c, nil, "FailedMultiStatements", func(dbt *DBTest) { +func (cli *testingServerClient) runFailedTestMultiStatements(t *testing.T) { + cli.runTestsOnNewDB(t, nil, "FailedMultiStatements", func(dbt *testkit.DBTestKit) { // Default is now OFF in new installations. // It is still WARN in upgrade installations (for now) - _, err := dbt.db.Exec("SELECT 1; SELECT 1; SELECT 2; SELECT 3;") - c.Assert(err.Error(), Equals, "Error 8130: client has multi-statement capability disabled. Run SET GLOBAL tidb_multi_statement_mode='ON' after you understand the security risk") + _, err := dbt.GetDB().Exec("SELECT 1; SELECT 1; SELECT 2; SELECT 3;") + require.Equal(t, "Error 8130: client has multi-statement capability disabled. Run SET GLOBAL tidb_multi_statement_mode='ON' after you understand the security risk", err.Error()) // Change to WARN (legacy mode) - dbt.mustExec("SET tidb_multi_statement_mode='WARN'") - dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") - res := dbt.mustExec("INSERT INTO test VALUES (1, 1)") + dbt.MustExec("SET tidb_multi_statement_mode='WARN'") + dbt.MustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") + res := dbt.MustExec("INSERT INTO test VALUES (1, 1)") count, err := res.RowsAffected() - c.Assert(err, IsNil, Commentf("res.RowsAffected() returned error")) - c.Assert(count, Equals, int64(1)) - res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") + require.NoErrorf(t, err, "res.RowsAffected() returned error") + require.Equal(t, int64(1), count) + res = dbt.MustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") count, err = res.RowsAffected() - c.Assert(err, IsNil, Commentf("res.RowsAffected() returned error")) - c.Assert(count, Equals, int64(1)) - rows := dbt.mustQuery("show warnings") - cli.checkRows(c, rows, "Warning 8130 client has multi-statement capability disabled. Run SET GLOBAL tidb_multi_statement_mode='ON' after you understand the security risk") + require.NoErrorf(t, err, "res.RowsAffected() returned error") + require.Equal(t, int64(1), count) + rows := dbt.MustQuery("show warnings") + cli.checkRows(t, rows, "Warning 8130 client has multi-statement capability disabled. Run SET GLOBAL tidb_multi_statement_mode='ON' after you understand the security risk") var out int - rows = dbt.mustQuery("SELECT value FROM test WHERE id=1;") + rows = dbt.MustQuery("SELECT value FROM test WHERE id=1;") if rows.Next() { err = rows.Scan(&out) - c.Assert(err, IsNil) - c.Assert(out, Equals, 5) + require.NoError(t, err) + require.Equal(t, 5, out) if rows.Next() { - dbt.Error("unexpected data") + require.Fail(t, "unexpected data") } } else { - dbt.Error("no data") + require.Fail(t, "no data") } // Change to ON = Fully supported, TiDB legacy. No warnings or Errors. - dbt.mustExec("SET tidb_multi_statement_mode='ON';") - dbt.mustExec("DROP TABLE IF EXISTS test") - dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") - res = dbt.mustExec("INSERT INTO test VALUES (1, 1)") + dbt.MustExec("SET tidb_multi_statement_mode='ON';") + dbt.MustExec("DROP TABLE IF EXISTS test") + dbt.MustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") + res = dbt.MustExec("INSERT INTO test VALUES (1, 1)") count, err = res.RowsAffected() - c.Assert(err, IsNil, Commentf("res.RowsAffected() returned error")) - c.Assert(count, Equals, int64(1)) - res = dbt.mustExec("update test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") + require.NoErrorf(t, err, "res.RowsAffected() returned error") + require.Equal(t, int64(1), count) + res = dbt.MustExec("update test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") count, err = res.RowsAffected() - c.Assert(err, IsNil, Commentf("res.RowsAffected() returned error")) - c.Assert(count, Equals, int64(1)) - rows = dbt.mustQuery("SELECT value FROM test WHERE id=1;") + require.NoErrorf(t, err, "res.RowsAffected() returned error") + require.Equal(t, int64(1), count) + rows = dbt.MustQuery("SELECT value FROM test WHERE id=1;") if rows.Next() { err = rows.Scan(&out) - c.Assert(err, IsNil) - c.Assert(out, Equals, 5) + require.NoError(t, err) + require.Equal(t, 5, out) if rows.Next() { - dbt.Error("unexpected data") + require.Fail(t, "unexpected data") } } else { - dbt.Error("no data") + require.Fail(t, "no data") } - }) } -func (cli *testServerClient) runTestMultiStatements(c *C) { +func (cli *testingServerClient) runTestMultiStatements(t *testing.T) { - cli.runTestsOnNewDB(c, func(config *mysql.Config) { + cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.Params["multiStatements"] = "true" - }, "MultiStatements", func(dbt *DBTest) { + }, "MultiStatements", func(dbt *testkit.DBTestKit) { // Create Table - dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") + dbt.MustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") // Create Data - res := dbt.mustExec("INSERT INTO test VALUES (1, 1)") + res := dbt.MustExec("INSERT INTO test VALUES (1, 1)") count, err := res.RowsAffected() - c.Assert(err, IsNil, Commentf("res.RowsAffected() returned error")) - c.Assert(count, Equals, int64(1)) + require.NoErrorf(t, err, "res.RowsAffected() returned error") + require.Equal(t, int64(1), count) // Update - res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") + res = dbt.MustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") count, err = res.RowsAffected() - c.Assert(err, IsNil, Commentf("res.RowsAffected() returned error")) - c.Assert(count, Equals, int64(1)) + require.NoErrorf(t, err, "res.RowsAffected() returned error") + require.Equal(t, int64(1), count) // Read var out int - rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") + rows := dbt.MustQuery("SELECT value FROM test WHERE id=1;") if rows.Next() { err = rows.Scan(&out) - c.Assert(err, IsNil) - c.Assert(out, Equals, 5) + require.NoError(t, err) + require.Equal(t, 5, out) if rows.Next() { - dbt.Error("unexpected data") + require.Fail(t, "unexpected data") } } else { - dbt.Error("no data") + require.Fail(t, "no data") } // Test issue #26688 // First we "reset" the CurrentDB by using a database and then dropping it. - dbt.mustExec("CREATE DATABASE dropme") - dbt.mustExec("USE dropme") - dbt.mustExec("DROP DATABASE dropme") + dbt.MustExec("CREATE DATABASE dropme") + dbt.MustExec("USE dropme") + dbt.MustExec("DROP DATABASE dropme") var usedb string - rows = dbt.mustQuery("SELECT IFNULL(DATABASE(),'success')") + rows = dbt.MustQuery("SELECT IFNULL(DATABASE(),'success')") if rows.Next() { err = rows.Scan(&usedb) - c.Assert(err, IsNil) - c.Assert(usedb, Equals, "success") + require.NoError(t, err) + require.Equal(t, "success", usedb) } else { - dbt.Error("no database() result") + require.Fail(t, "no database() result") } // Because no DB is selected, if the use multistmtuse is not successful, then // the create table + drop table statements will return errors. - dbt.mustExec("CREATE DATABASE multistmtuse") - dbt.mustExec("use multistmtuse; create table if not exists t1 (id int); drop table t1;") + dbt.MustExec("CREATE DATABASE multistmtuse") + dbt.MustExec("use multistmtuse; create table if not exists t1 (id int); drop table t1;") }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index 51661fb01e5f0..0018bec301a75 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -472,10 +472,13 @@ func newTLSHttpClient(t *testing.T, caFile, certFile, keyFile string) *http.Clie return &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} } -func (ts *tidbTestSuite) TestMultiStatements(c *C) { - c.Parallel() - ts.runFailedTestMultiStatements(c) - ts.runTestMultiStatements(c) +func TestMultiStatements(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runFailedTestMultiStatements(t) + ts.runTestMultiStatements(t) } func (ts *tidbTestSuite) TestSocketForwarding(c *C) { From e61b1c283adee2c51ae7ea48997a85ab5a1deb47 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 7 Nov 2021 19:26:46 +0800 Subject: [PATCH 29/55] server: migrate `TestSocketForwarding` --- server/tidb_test.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 0018bec301a75..e2450623a05da 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -481,8 +481,12 @@ func TestMultiStatements(t *testing.T) { ts.runTestMultiStatements(t) } -func (ts *tidbTestSuite) TestSocketForwarding(c *C) { - cli := newTestServerClient() +func TestSocketForwarding(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + cli := newTestingServerClient() cfg := newTestConfig() cfg.Socket = "/tmp/tidbtest.sock" cfg.Port = cli.port @@ -490,16 +494,16 @@ func (ts *tidbTestSuite) TestSocketForwarding(c *C) { cfg.Status.ReportStatus = false server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) + require.NoError(t, err) cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run() - c.Assert(err, IsNil) + require.NoError(t, err) }() time.Sleep(time.Millisecond * 100) defer server.Close() - cli.runTestRegression(c, func(config *mysql.Config) { + cli.runTestRegression(t, func(config *mysql.Config) { config.User = "root" config.Net = "unix" config.Addr = "/tmp/tidbtest.sock" From ddf36857fac1d5e5badec89577df3344d10aa2fb Mon Sep 17 00:00:00 2001 From: yedamo Date: Mon, 8 Nov 2021 20:56:58 +0800 Subject: [PATCH 30/55] testkit: Add comments for DBTestKit's methods --- testkit/dbtestkit.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/testkit/dbtestkit.go b/testkit/dbtestkit.go index f031c9b92bcc3..5e6ba74ec2fee 100644 --- a/testkit/dbtestkit.go +++ b/testkit/dbtestkit.go @@ -42,18 +42,23 @@ func NewDBTestKit(t *testing.T, db *sql.DB) *DBTestKit { } } +// MustPrepare creates a prepared statement for later queries or executions. func (tk *DBTestKit) MustPrepare(query string) *sql.Stmt { stmt, err := tk.db.Prepare(query) tk.require.NoErrorf(err, "Prepare %s", query) return stmt } +// MustExecPrepared executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. func (tk *DBTestKit) MustExecPrepared(stmt *sql.Stmt, args ...interface{}) sql.Result { res, err := stmt.Exec(args...) tk.require.NoErrorf(err, "Execute prepared with args: %s", args) return res } +// MustQueryPrepared executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. func (tk *DBTestKit) MustQueryPrepared(stmt *sql.Stmt, args ...interface{}) *sql.Rows { rows, err := stmt.Query(args...) tk.require.NoErrorf(err, "Query prepared with args: %s", args) @@ -78,6 +83,7 @@ func (tk *DBTestKit) MustQuery(sql string, args ...interface{}) *sql.Rows { return rows } +// MustQueryRows query the statements func (tk *DBTestKit) MustQueryRows(query string, args ...interface{}) { rows := tk.MustQuery(query, args...) tk.require.True(rows.Next()) From 83a661d1f737c3a2d1d133f35e5ba65423d38fa1 Mon Sep 17 00:00:00 2001 From: yedamo Date: Mon, 8 Nov 2021 21:01:30 +0800 Subject: [PATCH 31/55] server: migrate `TestGracefulShutdown` --- server/tidb_test.go | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index e2450623a05da..e4697c23c780e 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1598,15 +1598,11 @@ func TestNO_DEFAULT_VALUEFlag(t *testing.T) { require.Equal(t, expectFlag, dumpFlag(cols[0].Type, cols[0].Flag)) } -func (ts *tidbTestSuite) TestGracefulShutdown(c *C) { - store, err := mockstore.NewMockStore() - c.Assert(err, IsNil) - defer store.Close() - session.DisableStats4Test() - dom, err := session.BootstrapSession(store) - c.Assert(err, IsNil) - defer dom.Close() - ts.tidbdrv = NewTiDBDriver(ts.store) +func TestGracefulShutdown(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + cli := newTestServerClient() cfg := newTestConfig() cfg.Socket = "" @@ -1616,32 +1612,32 @@ func (ts *tidbTestSuite) TestGracefulShutdown(c *C) { cfg.Status.ReportStatus = true cfg.Performance.TCPKeepAlive = true server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) - c.Assert(server, NotNil) + require.NoError(t, err) + require.NotNil(t, server) cli.port = getPortFromTCPAddr(server.listener.Addr()) cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { err := server.Run() - c.Assert(err, IsNil) + require.NoError(t, err) }() time.Sleep(time.Millisecond * 100) resp, err := cli.fetchStatus("/status") // server is up - c.Assert(err, IsNil) - c.Assert(resp.Body.Close(), IsNil) + require.NoError(t, err) + require.Nil(t, resp.Body.Close()) go server.Close() time.Sleep(time.Millisecond * 500) resp, _ = cli.fetchStatus("/status") // should return 5xx code - c.Assert(resp.StatusCode, Equals, 500) - c.Assert(resp.Body.Close(), IsNil) + require.Equal(t, 500, resp.StatusCode) + require.Nil(t, resp.Body.Close()) time.Sleep(time.Second * 2) // nolint: bodyclose _, err = cli.fetchStatus("/status") // status is gone - c.Assert(err, ErrorMatches, ".*connect: connection refused") + require.Regexp(t, ".*connect: connection refused", err.Error()) } func (ts *tidbTestSerialSuite) TestDefaultCharacterAndCollation(c *C) { From 598076a9ac731f1fcfd7f86fc80689f4ecef982b Mon Sep 17 00:00:00 2001 From: yedamo Date: Mon, 8 Nov 2021 21:49:30 +0800 Subject: [PATCH 32/55] server: remove deprecated method --- server/server_test.go | 30 ------------------------------ testkit/dbtestkit.go | 1 + 2 files changed, 1 insertion(+), 30 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 8af4774d44ee4..b708b1cb8a34c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -231,24 +231,6 @@ type DBTest struct { db *sql.DB } -func (dbt *DBTest) mustPrepare(query string) *sql.Stmt { - stmt, err := dbt.db.Prepare(query) - dbt.Assert(err, IsNil, Commentf("Prepare %s", query)) - return stmt -} - -func (dbt *DBTest) mustExecPrepared(stmt *sql.Stmt, args ...interface{}) sql.Result { - res, err := stmt.Exec(args...) - dbt.Assert(err, IsNil, Commentf("Execute prepared with args: %s", args)) - return res -} - -func (dbt *DBTest) mustQueryPrepared(stmt *sql.Stmt, args ...interface{}) *sql.Rows { - rows, err := stmt.Query(args...) - dbt.Assert(err, IsNil, Commentf("Query prepared with args: %s", args)) - return rows -} - func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { res, err := dbt.db.Exec(query, args...) dbt.Assert(err, IsNil, Commentf("Exec %s", query)) @@ -1865,18 +1847,6 @@ func (cli *testingServerClient) runTestResultFieldTableIsNull(t *testing.T) { }) } -func (cli *testServerClient) runTestStatusAPI(c *C) { - resp, err := cli.fetchStatus("/status") - c.Assert(err, IsNil) - defer resp.Body.Close() - decoder := json.NewDecoder(resp.Body) - var data status - err = decoder.Decode(&data) - c.Assert(err, IsNil) - c.Assert(data.Version, Equals, tmysql.ServerVersion) - c.Assert(data.GitHash, Equals, versioninfo.TiDBGitHash) -} - func (cli *testingServerClient) runTestStatusAPI(t *testing.T) { resp, err := cli.fetchStatus("/status") require.NoError(t, err) diff --git a/testkit/dbtestkit.go b/testkit/dbtestkit.go index 5e6ba74ec2fee..5682597618683 100644 --- a/testkit/dbtestkit.go +++ b/testkit/dbtestkit.go @@ -87,6 +87,7 @@ func (tk *DBTestKit) MustQuery(sql string, args ...interface{}) *sql.Rows { func (tk *DBTestKit) MustQueryRows(query string, args ...interface{}) { rows := tk.MustQuery(query, args...) tk.require.True(rows.Next()) + tk.require.NoError(rows.Err()) rows.Close() } From 03c2a78cc091f09cb8eeeae5471a27b6a9e5a79b Mon Sep 17 00:00:00 2001 From: yedamo Date: Mon, 8 Nov 2021 22:04:26 +0800 Subject: [PATCH 33/55] server: migrate `TestTLSAuto` --- server/server_test.go | 15 +++++++++++++++ server/tidb_serial_test.go | 35 +++++++++++++++++++++++++++++++++++ server/tidb_test.go | 28 ---------------------------- 3 files changed, 50 insertions(+), 28 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index b708b1cb8a34c..be6a2d64cf792 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2031,6 +2031,21 @@ func (cli *testServerClient) runTestTLSConnection(t *C, overrider configOverride return err } +func (cli *testingServerClient) runTestTLSConnection(t *testing.T, overrider configOverrider) error { + dsn := cli.getDSN(overrider) + db, err := sql.Open("mysql", dsn) + require.NoError(t, err) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + _, err = db.Exec("USE test") + if err != nil { + return errors.Annotate(err, "dsn:"+dsn) + } + return err +} + func (cli *testServerClient) runReloadTLS(t *C, overrider configOverrider, errorNoRollback bool) error { db, err := sql.Open("mysql", cli.getDSN(overrider)) t.Assert(err, IsNil) diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 6dcb5a7b0a645..26b080c26de4b 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -15,9 +15,13 @@ package server import ( + "os" "testing" + "time" + "github.com/go-sql-driver/mysql" "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" ) // this test will change `kv.TxnTotalSizeLimit` which may affect other test suites, @@ -80,3 +84,34 @@ func TestLoadDataListPartition(t *testing.T) { ts.runTestLoadDataForListColumnPartition(t) ts.runTestLoadDataForListColumnPartition2(t) } + +func TestTLSAuto(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + // Start the server without TLS configure, letting the server create these as AutoTLS is enabled + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + } + cli := newTestingServerClient() + cfg := newTestConfig() + cfg.Socket = "" + cfg.Port = cli.port + cfg.Status.ReportStatus = false + cfg.Security.AutoTLS = true + cfg.Security.RSAKeySize = 528 // Reduces unittest runtime + err := os.MkdirAll(cfg.TempStoragePath, 0700) + require.NoError(t, err) + server, err := NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + cli.port = getPortFromTCPAddr(server.listener.Addr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + err = cli.runTestTLSConnection(t, connOverrider) // Relying on automatically created TLS certificates + require.NoError(t, err) + + server.Close() +} diff --git a/server/tidb_test.go b/server/tidb_test.go index e4697c23c780e..7683bbce9ac0f 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -960,34 +960,6 @@ func TestSystemTimeZone(t *testing.T) { tk.MustQuery("select @@system_time_zone").Check(tz1) } -func (ts *tidbTestSerialSuite) TestTLSAuto(c *C) { - // Start the server without TLS configure, letting the server create these as AutoTLS is enabled - connOverrider := func(config *mysql.Config) { - config.TLSConfig = "skip-verify" - } - cli := newTestServerClient() - cfg := newTestConfig() - cfg.Socket = "" - cfg.Port = cli.port - cfg.Status.ReportStatus = false - cfg.Security.AutoTLS = true - cfg.Security.RSAKeySize = 528 // Reduces unittest runtime - err := os.MkdirAll(cfg.TempStoragePath, 0700) - c.Assert(err, IsNil) - server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) - cli.port = getPortFromTCPAddr(server.listener.Addr()) - go func() { - err := server.Run() - c.Assert(err, IsNil) - }() - time.Sleep(time.Millisecond * 100) - err = cli.runTestTLSConnection(c, connOverrider) // Relying on automatically created TLS certificates - c.Assert(err, IsNil) - - server.Close() -} - func (ts *tidbTestSerialSuite) TestTLSBasic(c *C) { // Generate valid TLS certificates. caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key.pem", "/tmp/ca-cert.pem") From 1599147f7011493b4834bea19e86bb46915f1048 Mon Sep 17 00:00:00 2001 From: yedamo Date: Mon, 8 Nov 2021 22:14:11 +0800 Subject: [PATCH 34/55] server: migrate `TestTLSBasic` --- server/tidb_serial_test.go | 78 ++++++++++++++++++++++++++++++++++++++ server/tidb_test.go | 70 ---------------------------------- 2 files changed, 78 insertions(+), 70 deletions(-) diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 26b080c26de4b..95f43d06cb97a 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -20,6 +20,9 @@ import ( "time" "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/stretchr/testify/require" ) @@ -115,3 +118,78 @@ func TestTLSAuto(t *testing.T) { server.Close() } + +func TestTLSBasic(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key.pem", "/tmp/ca-cert.pem") + require.NoError(t, err) + serverCert, _, err := generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key.pem", "/tmp/server-cert.pem") + require.NoError(t, err) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key.pem", "/tmp/client-cert.pem") + require.NoError(t, err) + err = registerTLSConfig("client-certificate", "/tmp/ca-cert.pem", "/tmp/client-cert.pem", "/tmp/client-key.pem", "tidb-server", true) + require.NoError(t, err) + + defer func() { + err := os.Remove("/tmp/ca-key.pem") + require.NoError(t, err) + err = os.Remove("/tmp/ca-cert.pem") + require.NoError(t, err) + err = os.Remove("/tmp/server-key.pem") + require.NoError(t, err) + err = os.Remove("/tmp/server-cert.pem") + require.NoError(t, err) + err = os.Remove("/tmp/client-key.pem") + require.NoError(t, err) + err = os.Remove("/tmp/client-cert.pem") + require.NoError(t, err) + }() + + // Start the server with TLS but without CA, in this case the server will not verify client's certificate. + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + } + cli := newTestingServerClient() + cfg := newTestConfig() + cfg.Socket = "" + cfg.Port = cli.port + cfg.Status.ReportStatus = false + cfg.Security = config.Security{ + SSLCert: "/tmp/server-cert.pem", + SSLKey: "/tmp/server-key.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + cli.port = getPortFromTCPAddr(server.listener.Addr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + err = cli.runTestTLSConnection(t, connOverrider) // We should establish connection successfully. + require.NoError(t, err) + cli.runTestRegression(t, connOverrider, "TLSRegression") + // Perform server verification. + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate" + } + err = cli.runTestTLSConnection(t, connOverrider) // We should establish connection successfully. + require.NoError(t, err, "%v", errors.ErrorStack(err)) + cli.runTestRegression(t, connOverrider, "TLSRegression") + + // Test SSL/TLS session vars + var v *variable.SessionVars + stats, err := server.Stats(v) + require.NoError(t, err) + _, hasKey := stats["Ssl_server_not_after"] + require.True(t, hasKey) + _, hasKey = stats["Ssl_server_not_before"] + require.True(t, hasKey) + require.Equal(t, serverCert.NotAfter.Format("Jan _2 15:04:05 2006 MST"), stats["Ssl_server_not_after"]) + require.Equal(t, serverCert.NotBefore.Format("Jan _2 15:04:05 2006 MST"), stats["Ssl_server_not_before"]) + + server.Close() +} diff --git a/server/tidb_test.go b/server/tidb_test.go index 7683bbce9ac0f..cd07ac2660040 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -960,76 +960,6 @@ func TestSystemTimeZone(t *testing.T) { tk.MustQuery("select @@system_time_zone").Check(tz1) } -func (ts *tidbTestSerialSuite) TestTLSBasic(c *C) { - // Generate valid TLS certificates. - caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key.pem", "/tmp/ca-cert.pem") - c.Assert(err, IsNil) - serverCert, _, err := generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key.pem", "/tmp/server-cert.pem") - c.Assert(err, IsNil) - _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key.pem", "/tmp/client-cert.pem") - c.Assert(err, IsNil) - err = registerTLSConfig("client-certificate", "/tmp/ca-cert.pem", "/tmp/client-cert.pem", "/tmp/client-key.pem", "tidb-server", true) - c.Assert(err, IsNil) - - defer func() { - err := os.Remove("/tmp/ca-key.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/ca-cert.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/server-key.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/server-cert.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/client-key.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/client-cert.pem") - c.Assert(err, IsNil) - }() - - // Start the server with TLS but without CA, in this case the server will not verify client's certificate. - connOverrider := func(config *mysql.Config) { - config.TLSConfig = "skip-verify" - } - cli := newTestServerClient() - cfg := newTestConfig() - cfg.Socket = "" - cfg.Port = cli.port - cfg.Status.ReportStatus = false - cfg.Security = config.Security{ - SSLCert: "/tmp/server-cert.pem", - SSLKey: "/tmp/server-key.pem", - } - server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) - cli.port = getPortFromTCPAddr(server.listener.Addr()) - go func() { - err := server.Run() - c.Assert(err, IsNil) - }() - time.Sleep(time.Millisecond * 100) - err = cli.runTestTLSConnection(c, connOverrider) // We should establish connection successfully. - c.Assert(err, IsNil) - cli.runTestRegression(c, connOverrider, "TLSRegression") - // Perform server verification. - connOverrider = func(config *mysql.Config) { - config.TLSConfig = "client-certificate" - } - err = cli.runTestTLSConnection(c, connOverrider) // We should establish connection successfully. - c.Assert(err, IsNil, Commentf("%v", errors.ErrorStack(err))) - cli.runTestRegression(c, connOverrider, "TLSRegression") - - // Test SSL/TLS session vars - var v *variable.SessionVars - stats, err := server.Stats(v) - c.Assert(err, IsNil) - c.Assert(stats, HasKey, "Ssl_server_not_after") - c.Assert(stats, HasKey, "Ssl_server_not_before") - c.Assert(stats["Ssl_server_not_after"], Equals, serverCert.NotAfter.Format("Jan _2 15:04:05 2006 MST")) - c.Assert(stats["Ssl_server_not_before"], Equals, serverCert.NotBefore.Format("Jan _2 15:04:05 2006 MST")) - - server.Close() -} - func (ts *tidbTestSerialSuite) TestTLSVerify(c *C) { // Generate valid TLS certificates. caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key.pem", "/tmp/ca-cert.pem") From 6476267a75e1fcd64b6677a920f04f602b0ca41c Mon Sep 17 00:00:00 2001 From: yedamo Date: Mon, 8 Nov 2021 22:23:02 +0800 Subject: [PATCH 35/55] server: migrate `TestTLSVerify` --- server/tidb_serial_test.go | 76 ++++++++++++++++++++++++++++++++++++++ server/tidb_test.go | 71 ----------------------------------- 2 files changed, 76 insertions(+), 71 deletions(-) diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 95f43d06cb97a..5d88c655d5c5d 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -15,6 +15,7 @@ package server import ( + "crypto/x509" "os" "testing" "time" @@ -24,6 +25,7 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util" "github.com/stretchr/testify/require" ) @@ -193,3 +195,77 @@ func TestTLSBasic(t *testing.T) { server.Close() } + +func TestTLSVerify(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key.pem", "/tmp/ca-cert.pem") + require.NoError(t, err) + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key.pem", "/tmp/server-cert.pem") + require.NoError(t, err) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key.pem", "/tmp/client-cert.pem") + require.NoError(t, err) + err = registerTLSConfig("client-certificate", "/tmp/ca-cert.pem", "/tmp/client-cert.pem", "/tmp/client-key.pem", "tidb-server", true) + require.NoError(t, err) + + defer func() { + err := os.Remove("/tmp/ca-key.pem") + require.NoError(t, err) + err = os.Remove("/tmp/ca-cert.pem") + require.NoError(t, err) + err = os.Remove("/tmp/server-key.pem") + require.NoError(t, err) + err = os.Remove("/tmp/server-cert.pem") + require.NoError(t, err) + err = os.Remove("/tmp/client-key.pem") + require.NoError(t, err) + err = os.Remove("/tmp/client-cert.pem") + require.NoError(t, err) + }() + + // Start the server with TLS & CA, if the client presents its certificate, the certificate will be verified. + cli := newTestingServerClient() + cfg := newTestConfig() + cfg.Socket = "" + cfg.Port = cli.port + cfg.Status.ReportStatus = false + cfg.Security = config.Security{ + SSLCA: "/tmp/ca-cert.pem", + SSLCert: "/tmp/server-cert.pem", + SSLKey: "/tmp/server-key.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + cli.port = getPortFromTCPAddr(server.listener.Addr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + // The client does not provide a certificate, the connection should succeed. + err = cli.runTestTLSConnection(t, nil) + require.NoError(t, err) + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "client-certificate" + } + cli.runTestRegression(t, connOverrider, "TLSRegression") + // The client provides a valid certificate. + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate" + } + err = cli.runTestTLSConnection(t, connOverrider) + require.NoError(t, err) + cli.runTestRegression(t, connOverrider, "TLSRegression") + server.Close() + + require.False(t, util.IsTLSExpiredError(errors.New("unknown test"))) + require.False(t, util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.CANotAuthorizedForThisName})) + require.True(t, util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.Expired})) + + _, _, err = util.LoadTLSCertificates("", "wrong key", "wrong cert", true, 528) + require.Error(t, err) + _, _, err = util.LoadTLSCertificates("wrong ca", "/tmp/server-key.pem", "/tmp/server-cert.pem", true, 528) + require.Error(t, err) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index cd07ac2660040..691bd417059c1 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -960,77 +960,6 @@ func TestSystemTimeZone(t *testing.T) { tk.MustQuery("select @@system_time_zone").Check(tz1) } -func (ts *tidbTestSerialSuite) TestTLSVerify(c *C) { - // Generate valid TLS certificates. - caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key.pem", "/tmp/ca-cert.pem") - c.Assert(err, IsNil) - _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key.pem", "/tmp/server-cert.pem") - c.Assert(err, IsNil) - _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key.pem", "/tmp/client-cert.pem") - c.Assert(err, IsNil) - err = registerTLSConfig("client-certificate", "/tmp/ca-cert.pem", "/tmp/client-cert.pem", "/tmp/client-key.pem", "tidb-server", true) - c.Assert(err, IsNil) - - defer func() { - err := os.Remove("/tmp/ca-key.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/ca-cert.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/server-key.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/server-cert.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/client-key.pem") - c.Assert(err, IsNil) - err = os.Remove("/tmp/client-cert.pem") - c.Assert(err, IsNil) - }() - - // Start the server with TLS & CA, if the client presents its certificate, the certificate will be verified. - cli := newTestServerClient() - cfg := newTestConfig() - cfg.Socket = "" - cfg.Port = cli.port - cfg.Status.ReportStatus = false - cfg.Security = config.Security{ - SSLCA: "/tmp/ca-cert.pem", - SSLCert: "/tmp/server-cert.pem", - SSLKey: "/tmp/server-key.pem", - } - server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) - cli.port = getPortFromTCPAddr(server.listener.Addr()) - go func() { - err := server.Run() - c.Assert(err, IsNil) - }() - time.Sleep(time.Millisecond * 100) - // The client does not provide a certificate, the connection should succeed. - err = cli.runTestTLSConnection(c, nil) - c.Assert(err, IsNil) - connOverrider := func(config *mysql.Config) { - config.TLSConfig = "client-certificate" - } - cli.runTestRegression(c, connOverrider, "TLSRegression") - // The client provides a valid certificate. - connOverrider = func(config *mysql.Config) { - config.TLSConfig = "client-certificate" - } - err = cli.runTestTLSConnection(c, connOverrider) - c.Assert(err, IsNil) - cli.runTestRegression(c, connOverrider, "TLSRegression") - server.Close() - - c.Assert(util.IsTLSExpiredError(errors.New("unknown test")), IsFalse) - c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.CANotAuthorizedForThisName}), IsFalse) - c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.Expired}), IsTrue) - - _, _, err = util.LoadTLSCertificates("", "wrong key", "wrong cert", true, 528) - c.Assert(err, NotNil) - _, _, err = util.LoadTLSCertificates("wrong ca", "/tmp/server-key.pem", "/tmp/server-cert.pem", true, 528) - c.Assert(err, NotNil) -} - func (ts *tidbTestSerialSuite) TestReloadTLS(c *C) { // Generate valid TLS certificates. caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-reload.pem", "/tmp/ca-cert-reload.pem") From 227ab9783bcd5e55ad8fe24579b70dde28de9cd6 Mon Sep 17 00:00:00 2001 From: yedamo Date: Tue, 9 Nov 2021 20:29:14 +0800 Subject: [PATCH 36/55] server: migrate `TestClientWithCollation` --- server/server_test.go | 37 +++++++++++++++++++------------------ server/tidb_test.go | 9 ++++++--- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index be6a2d64cf792..7fa0da86776c8 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -443,38 +443,39 @@ func (cli *testingServerClient) runTestSpecialType(t *testing.T) { }) } -func (cli *testServerClient) runTestClientWithCollation(t *C) { +func (cli *testingServerClient) runTestClientWithCollation(t *testing.T) { cli.runTests(t, func(config *mysql.Config) { config.Collation = "utf8mb4_general_ci" - }, func(dbt *DBTest) { + }, func(dbt *testkit.DBTestKit) { var name, charset, collation string // check session variable collation_connection - rows := dbt.mustQuery("show variables like 'collation_connection'") - t.Assert(rows.Next(), IsTrue) + rows := dbt.MustQuery("show variables like 'collation_connection'") + require.True(t, rows.Next()) + err := rows.Scan(&name, &collation) - t.Assert(err, IsNil) - t.Assert(collation, Equals, "utf8mb4_general_ci") + require.NoError(t, err) + require.Equal(t, "utf8mb4_general_ci", collation) // check session variable character_set_client - rows = dbt.mustQuery("show variables like 'character_set_client'") - t.Assert(rows.Next(), IsTrue) + rows = dbt.MustQuery("show variables like 'character_set_client'") + require.True(t, rows.Next()) err = rows.Scan(&name, &charset) - t.Assert(err, IsNil) - t.Assert(charset, Equals, "utf8mb4") + require.NoError(t, err) + require.Equal(t, "utf8mb4", charset) // check session variable character_set_results - rows = dbt.mustQuery("show variables like 'character_set_results'") - t.Assert(rows.Next(), IsTrue) + rows = dbt.MustQuery("show variables like 'character_set_results'") + require.True(t, rows.Next()) err = rows.Scan(&name, &charset) - t.Assert(err, IsNil) - t.Assert(charset, Equals, "utf8mb4") + require.NoError(t, err) + require.Equal(t, "utf8mb4", charset) // check session variable character_set_connection - rows = dbt.mustQuery("show variables like 'character_set_connection'") - t.Assert(rows.Next(), IsTrue) + rows = dbt.MustQuery("show variables like 'character_set_connection'") + require.True(t, rows.Next()) err = rows.Scan(&name, &charset) - t.Assert(err, IsNil) - t.Assert(charset, Equals, "utf8mb4") + require.NoError(t, err) + require.Equal(t, "utf8mb4", charset) }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index 691bd417059c1..e88b190899094 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1129,9 +1129,12 @@ func (ts *tidbTestSerialSuite) TestErrorNoRollback(c *C) { c.Assert(tlsCfg, IsNil) } -func (ts *tidbTestSuite) TestClientWithCollation(c *C) { - c.Parallel() - ts.runTestClientWithCollation(c) +func TestClientWithCollation(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + + ts.runTestClientWithCollation(t) } func TestCreateTableFlen(t *testing.T) { From 793b93134b8b8785f3b98ef4c7a28e23515d616c Mon Sep 17 00:00:00 2001 From: yedamo Date: Tue, 9 Nov 2021 22:15:37 +0800 Subject: [PATCH 37/55] server: migrate `TestErrorNoRollback` --- server/server_test.go | 15 ++++++++ server/tidb_serial_test.go | 70 ++++++++++++++++++++++++++++++++++++++ server/tidb_test.go | 67 ------------------------------------ 3 files changed, 85 insertions(+), 67 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 7fa0da86776c8..6ed434bf704f1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2062,6 +2062,21 @@ func (cli *testServerClient) runReloadTLS(t *C, overrider configOverrider, error return err } +func (cli *testingServerClient) runReloadTLS(t *testing.T, overrider configOverrider, errorNoRollback bool) error { + db, err := sql.Open("mysql", cli.getDSN(overrider)) + require.NoError(t, err) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + sql := "alter instance reload tls" + if errorNoRollback { + sql += " no rollback on error" + } + _, err = db.Exec(sql) + return err +} + func (cli *testServerClient) runTestSumAvg(c *C) { cli.runTests(c, nil, func(dbt *DBTest) { dbt.mustExec("create table sumavg (a int, b decimal, c double)") diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 5d88c655d5c5d..0318cb939eb17 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -269,3 +269,73 @@ func TestTLSVerify(t *testing.T) { _, _, err = util.LoadTLSCertificates("wrong ca", "/tmp/server-key.pem", "/tmp/server-cert.pem", true, 528) require.Error(t, err) } + +func TestErrorNoRollback(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-rollback.pem", "/tmp/ca-cert-rollback.pem") + require.NoError(t, err) + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-rollback.pem", "/tmp/server-cert-rollback.pem") + require.NoError(t, err) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-rollback.pem", "/tmp/client-cert-rollback.pem") + require.NoError(t, err) + err = registerTLSConfig("client-cert-rollback-test", "/tmp/ca-cert-rollback.pem", "/tmp/client-cert-rollback.pem", "/tmp/client-key-rollback.pem", "tidb-server", true) + require.NoError(t, err) + + defer func() { + os.Remove("/tmp/ca-key-rollback.pem") + os.Remove("/tmp/ca-cert-rollback.pem") + + os.Remove("/tmp/server-key-rollback.pem") + os.Remove("/tmp/server-cert-rollback.pem") + os.Remove("/tmp/client-key-rollback.pem") + os.Remove("/tmp/client-cert-rollback.pem") + }() + + cli := newTestingServerClient() + cfg := newTestConfig() + cfg.Socket = "" + cfg.Port = cli.port + cfg.Status.ReportStatus = false + + cfg.Security = config.Security{ + RequireSecureTransport: true, + SSLCA: "wrong path", + SSLCert: "wrong path", + SSLKey: "wrong path", + } + _, err = NewServer(cfg, ts.tidbdrv) + require.Error(t, err) + + // test reload tls fail with/without "error no rollback option" + cfg.Security = config.Security{ + SSLCA: "/tmp/ca-cert-rollback.pem", + SSLCert: "/tmp/server-cert-rollback.pem", + SSLKey: "/tmp/server-key-rollback.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + cli.port = getPortFromTCPAddr(server.listener.Addr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + defer server.Close() + time.Sleep(time.Millisecond * 100) + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "client-cert-rollback-test" + } + err = cli.runTestTLSConnection(t, connOverrider) + require.NoError(t, err) + os.Remove("/tmp/server-key-rollback.pem") + err = cli.runReloadTLS(t, connOverrider, false) + require.Error(t, err) + tlsCfg := server.getTLSConfig() + require.NotNil(t, tlsCfg) + err = cli.runReloadTLS(t, connOverrider, true) + require.NoError(t, err) + tlsCfg = server.getTLSConfig() + require.Nil(t, tlsCfg) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index e88b190899094..ea7bc6339e2e2 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1062,73 +1062,6 @@ func (ts *tidbTestSerialSuite) TestReloadTLS(c *C) { server.Close() } -func (ts *tidbTestSerialSuite) TestErrorNoRollback(c *C) { - // Generate valid TLS certificates. - caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-rollback.pem", "/tmp/ca-cert-rollback.pem") - c.Assert(err, IsNil) - _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-rollback.pem", "/tmp/server-cert-rollback.pem") - c.Assert(err, IsNil) - _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-rollback.pem", "/tmp/client-cert-rollback.pem") - c.Assert(err, IsNil) - err = registerTLSConfig("client-cert-rollback-test", "/tmp/ca-cert-rollback.pem", "/tmp/client-cert-rollback.pem", "/tmp/client-key-rollback.pem", "tidb-server", true) - c.Assert(err, IsNil) - - defer func() { - os.Remove("/tmp/ca-key-rollback.pem") - os.Remove("/tmp/ca-cert-rollback.pem") - - os.Remove("/tmp/server-key-rollback.pem") - os.Remove("/tmp/server-cert-rollback.pem") - os.Remove("/tmp/client-key-rollback.pem") - os.Remove("/tmp/client-cert-rollback.pem") - }() - - cli := newTestServerClient() - cfg := newTestConfig() - cfg.Socket = "" - cfg.Port = cli.port - cfg.Status.ReportStatus = false - - cfg.Security = config.Security{ - RequireSecureTransport: true, - SSLCA: "wrong path", - SSLCert: "wrong path", - SSLKey: "wrong path", - } - _, err = NewServer(cfg, ts.tidbdrv) - c.Assert(err, NotNil) - - // test reload tls fail with/without "error no rollback option" - cfg.Security = config.Security{ - SSLCA: "/tmp/ca-cert-rollback.pem", - SSLCert: "/tmp/server-cert-rollback.pem", - SSLKey: "/tmp/server-key-rollback.pem", - } - server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) - cli.port = getPortFromTCPAddr(server.listener.Addr()) - go func() { - err := server.Run() - c.Assert(err, IsNil) - }() - defer server.Close() - time.Sleep(time.Millisecond * 100) - connOverrider := func(config *mysql.Config) { - config.TLSConfig = "client-cert-rollback-test" - } - err = cli.runTestTLSConnection(c, connOverrider) - c.Assert(err, IsNil) - os.Remove("/tmp/server-key-rollback.pem") - err = cli.runReloadTLS(c, connOverrider, false) - c.Assert(err, NotNil) - tlsCfg := server.getTLSConfig() - c.Assert(tlsCfg, NotNil) - err = cli.runReloadTLS(c, connOverrider, true) - c.Assert(err, IsNil) - tlsCfg = server.getTLSConfig() - c.Assert(tlsCfg, IsNil) -} - func TestClientWithCollation(t *testing.T) { t.Parallel() ts, cleanup := createTiDBTest(t) From 8adcd51a4f0a54607791c1ff69cd6f3be1be4619 Mon Sep 17 00:00:00 2001 From: yedamo Date: Tue, 9 Nov 2021 22:20:54 +0800 Subject: [PATCH 38/55] server: migrate `TestPrepareCount` --- server/tidb_serial_test.go | 26 ++++++++++++++++++++++++++ server/tidb_test.go | 22 ---------------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 0318cb939eb17..60735b31496f4 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -15,14 +15,17 @@ package server import ( + "context" "crypto/x509" "os" + "sync/atomic" "testing" "time" "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" "github.com/pingcap/tidb/config" + tmysql "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" @@ -339,3 +342,26 @@ func TestErrorNoRollback(t *testing.T) { tlsCfg = server.getTLSConfig() require.Nil(t, tlsCfg) } + +func TestPrepareCount(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) + require.NoError(t, err) + prepareCnt := atomic.LoadInt64(&variable.PreparedStmtCount) + ctx := context.Background() + _, err = Execute(ctx, qctx, "use test;") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "drop table if exists t1") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "create table t1 (id int)") + require.NoError(t, err) + stmt, _, _, err := qctx.Prepare("insert into t1 values (?)") + require.NoError(t, err) + require.Equal(t, prepareCnt+1, atomic.LoadInt64(&variable.PreparedStmtCount)) + require.NoError(t, err) + err = qctx.GetStatement(stmt.ID()).Close() + require.NoError(t, err) + require.Equal(t, prepareCnt, atomic.LoadInt64(&variable.PreparedStmtCount)) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index ea7bc6339e2e2..c677ae3d6c4cb 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -32,7 +32,6 @@ import ( "os" "path/filepath" "strings" - "sync/atomic" "testing" "time" @@ -47,7 +46,6 @@ import ( "github.com/pingcap/tidb/parser" tmysql "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/session" - "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" @@ -1456,26 +1454,6 @@ func TestPessimisticInsertSelectForUpdate(t *testing.T) { require.Nil(t, rs) // should be no delay } -func (ts *tidbTestSerialSuite) TestPrepareCount(c *C) { - qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) - c.Assert(err, IsNil) - prepareCnt := atomic.LoadInt64(&variable.PreparedStmtCount) - ctx := context.Background() - _, err = Execute(ctx, qctx, "use test;") - c.Assert(err, IsNil) - _, err = Execute(ctx, qctx, "drop table if exists t1") - c.Assert(err, IsNil) - _, err = Execute(ctx, qctx, "create table t1 (id int)") - c.Assert(err, IsNil) - stmt, _, _, err := qctx.Prepare("insert into t1 values (?)") - c.Assert(err, IsNil) - c.Assert(atomic.LoadInt64(&variable.PreparedStmtCount), Equals, prepareCnt+1) - c.Assert(err, IsNil) - err = qctx.GetStatement(stmt.ID()).Close() - c.Assert(err, IsNil) - c.Assert(atomic.LoadInt64(&variable.PreparedStmtCount), Equals, prepareCnt) -} - type collectorWrapper struct { reporter.TopSQLReporter } From d8029683437114a4720dcafd5cc1d696546cab97 Mon Sep 17 00:00:00 2001 From: yedamo Date: Tue, 9 Nov 2021 22:27:48 +0800 Subject: [PATCH 39/55] server: migrate `TestDefaultCharacterAndCollation` --- server/tidb_serial_test.go | 27 +++++++++++++++++++++++++++ server/tidb_test.go | 24 ------------------------ 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 60735b31496f4..4162123a30971 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/collate" "github.com/stretchr/testify/require" ) @@ -365,3 +366,29 @@ func TestPrepareCount(t *testing.T) { require.NoError(t, err) require.Equal(t, prepareCnt, atomic.LoadInt64(&variable.PreparedStmtCount)) } + +func TestDefaultCharacterAndCollation(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + // issue #21194 + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + // 255 is the collation id of mysql client 8 default collation_connection + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(255), "test", nil) + require.NoError(t, err) + testCase := []struct { + variable string + except string + }{ + {"collation_connection", "utf8mb4_bin"}, + {"character_set_connection", "utf8mb4"}, + {"character_set_client", "utf8mb4"}, + } + + for _, tc := range testCase { + sVars, b := qctx.GetSessionVars().GetSystemVar(tc.variable) + require.True(t, b) + require.Equal(t, tc.except, sVars) + } +} diff --git a/server/tidb_test.go b/server/tidb_test.go index c677ae3d6c4cb..c9d08980c5fac 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -49,7 +49,6 @@ import ( "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" - "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/plancodec" "github.com/pingcap/tidb/util/topsql/reporter" @@ -1405,29 +1404,6 @@ func TestGracefulShutdown(t *testing.T) { require.Regexp(t, ".*connect: connection refused", err.Error()) } -func (ts *tidbTestSerialSuite) TestDefaultCharacterAndCollation(c *C) { - // issue #21194 - collate.SetNewCollationEnabledForTest(true) - defer collate.SetNewCollationEnabledForTest(false) - // 255 is the collation id of mysql client 8 default collation_connection - qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(255), "test", nil) - c.Assert(err, IsNil) - testCase := []struct { - variable string - except string - }{ - {"collation_connection", "utf8mb4_bin"}, - {"character_set_connection", "utf8mb4"}, - {"character_set_client", "utf8mb4"}, - } - - for _, t := range testCase { - sVars, b := qctx.GetSessionVars().GetSystemVar(t.variable) - c.Assert(b, IsTrue) - c.Assert(sVars, Equals, t.except) - } -} - func TestPessimisticInsertSelectForUpdate(t *testing.T) { t.Parallel() ts, cleanup := createTiDBTest(t) From e1e786bbfe8ab55e67abfc518c3f70c726a22af9 Mon Sep 17 00:00:00 2001 From: yedamo Date: Tue, 9 Nov 2021 22:58:17 +0800 Subject: [PATCH 40/55] server: migrate `TestReloadTLS` --- server/tidb_serial_test.go | 105 +++++++++++++++++++++++++++++++++++++ server/tidb_test.go | 102 ----------------------------------- 2 files changed, 105 insertions(+), 102 deletions(-) diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 4162123a30971..353c9bb44d9c9 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -392,3 +392,108 @@ func TestDefaultCharacterAndCollation(t *testing.T) { require.Equal(t, tc.except, sVars) } } + +func TestReloadTLS(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-reload.pem", "/tmp/ca-cert-reload.pem") + require.NoError(t, err) + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload.pem", "/tmp/server-cert-reload.pem") + require.NoError(t, err) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-reload.pem", "/tmp/client-cert-reload.pem") + require.NoError(t, err) + err = registerTLSConfig("client-certificate-reload", "/tmp/ca-cert-reload.pem", "/tmp/client-cert-reload.pem", "/tmp/client-key-reload.pem", "tidb-server", true) + require.NoError(t, err) + + defer func() { + os.Remove("/tmp/ca-key-reload.pem") + os.Remove("/tmp/ca-cert-reload.pem") + + os.Remove("/tmp/server-key-reload.pem") + os.Remove("/tmp/server-cert-reload.pem") + os.Remove("/tmp/client-key-reload.pem") + os.Remove("/tmp/client-cert-reload.pem") + }() + + // try old cert used in startup configuration. + cli := newTestingServerClient() + cfg := newTestConfig() + cfg.Socket = "" + cfg.Port = cli.port + cfg.Status.ReportStatus = false + cfg.Security = config.Security{ + SSLCA: "/tmp/ca-cert-reload.pem", + SSLCert: "/tmp/server-cert-reload.pem", + SSLKey: "/tmp/server-key-reload.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + cli.port = getPortFromTCPAddr(server.listener.Addr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + // The client provides a valid certificate. + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + } + err = cli.runTestTLSConnection(t, connOverrider) + require.NoError(t, err) + + // try reload a valid cert. + tlsCfg := server.getTLSConfig() + cert, err := x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) + require.NoError(t, err) + oldExpireTime := cert.NotAfter + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload2.pem", "/tmp/server-cert-reload2.pem", func(c *x509.Certificate) { + c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() + c.NotAfter = time.Now().Add(1 * time.Hour).UTC() + }) + require.NoError(t, err) + err = os.Rename("/tmp/server-key-reload2.pem", "/tmp/server-key-reload.pem") + require.NoError(t, err) + err = os.Rename("/tmp/server-cert-reload2.pem", "/tmp/server-cert-reload.pem") + require.NoError(t, err) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + } + err = cli.runReloadTLS(t, connOverrider, false) + require.NoError(t, err) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + } + err = cli.runTestTLSConnection(t, connOverrider) + require.NoError(t, err) + + tlsCfg = server.getTLSConfig() + cert, err = x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) + require.NoError(t, err) + newExpireTime := cert.NotAfter + require.True(t, newExpireTime.After(oldExpireTime)) + + // try reload a expired cert. + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload3.pem", "/tmp/server-cert-reload3.pem", func(c *x509.Certificate) { + c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() + c.NotAfter = c.NotBefore.Add(1 * time.Hour).UTC() + }) + require.NoError(t, err) + err = os.Rename("/tmp/server-key-reload3.pem", "/tmp/server-key-reload.pem") + require.NoError(t, err) + err = os.Rename("/tmp/server-cert-reload3.pem", "/tmp/server-cert-reload.pem") + require.NoError(t, err) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + } + err = cli.runReloadTLS(t, connOverrider, false) + require.NoError(t, err) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + } + err = cli.runTestTLSConnection(t, connOverrider) + require.NotNil(t, err) + require.Truef(t, util.IsTLSExpiredError(err), "real error is %+v", err) + server.Close() +} diff --git a/server/tidb_test.go b/server/tidb_test.go index c9d08980c5fac..678054ed1c455 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -957,108 +957,6 @@ func TestSystemTimeZone(t *testing.T) { tk.MustQuery("select @@system_time_zone").Check(tz1) } -func (ts *tidbTestSerialSuite) TestReloadTLS(c *C) { - // Generate valid TLS certificates. - caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-reload.pem", "/tmp/ca-cert-reload.pem") - c.Assert(err, IsNil) - _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload.pem", "/tmp/server-cert-reload.pem") - c.Assert(err, IsNil) - _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-reload.pem", "/tmp/client-cert-reload.pem") - c.Assert(err, IsNil) - err = registerTLSConfig("client-certificate-reload", "/tmp/ca-cert-reload.pem", "/tmp/client-cert-reload.pem", "/tmp/client-key-reload.pem", "tidb-server", true) - c.Assert(err, IsNil) - - defer func() { - os.Remove("/tmp/ca-key-reload.pem") - os.Remove("/tmp/ca-cert-reload.pem") - - os.Remove("/tmp/server-key-reload.pem") - os.Remove("/tmp/server-cert-reload.pem") - os.Remove("/tmp/client-key-reload.pem") - os.Remove("/tmp/client-cert-reload.pem") - }() - - // try old cert used in startup configuration. - cli := newTestServerClient() - cfg := newTestConfig() - cfg.Socket = "" - cfg.Port = cli.port - cfg.Status.ReportStatus = false - cfg.Security = config.Security{ - SSLCA: "/tmp/ca-cert-reload.pem", - SSLCert: "/tmp/server-cert-reload.pem", - SSLKey: "/tmp/server-key-reload.pem", - } - server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) - cli.port = getPortFromTCPAddr(server.listener.Addr()) - go func() { - err := server.Run() - c.Assert(err, IsNil) - }() - time.Sleep(time.Millisecond * 100) - // The client provides a valid certificate. - connOverrider := func(config *mysql.Config) { - config.TLSConfig = "client-certificate-reload" - } - err = cli.runTestTLSConnection(c, connOverrider) - c.Assert(err, IsNil) - - // try reload a valid cert. - tlsCfg := server.getTLSConfig() - cert, err := x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) - c.Assert(err, IsNil) - oldExpireTime := cert.NotAfter - _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload2.pem", "/tmp/server-cert-reload2.pem", func(c *x509.Certificate) { - c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() - c.NotAfter = time.Now().Add(1 * time.Hour).UTC() - }) - c.Assert(err, IsNil) - err = os.Rename("/tmp/server-key-reload2.pem", "/tmp/server-key-reload.pem") - c.Assert(err, IsNil) - err = os.Rename("/tmp/server-cert-reload2.pem", "/tmp/server-cert-reload.pem") - c.Assert(err, IsNil) - connOverrider = func(config *mysql.Config) { - config.TLSConfig = "skip-verify" - } - err = cli.runReloadTLS(c, connOverrider, false) - c.Assert(err, IsNil) - connOverrider = func(config *mysql.Config) { - config.TLSConfig = "client-certificate-reload" - } - err = cli.runTestTLSConnection(c, connOverrider) - c.Assert(err, IsNil) - - tlsCfg = server.getTLSConfig() - cert, err = x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) - c.Assert(err, IsNil) - newExpireTime := cert.NotAfter - c.Assert(newExpireTime.After(oldExpireTime), IsTrue) - - // try reload a expired cert. - _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload3.pem", "/tmp/server-cert-reload3.pem", func(c *x509.Certificate) { - c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() - c.NotAfter = c.NotBefore.Add(1 * time.Hour).UTC() - }) - c.Assert(err, IsNil) - err = os.Rename("/tmp/server-key-reload3.pem", "/tmp/server-key-reload.pem") - c.Assert(err, IsNil) - err = os.Rename("/tmp/server-cert-reload3.pem", "/tmp/server-cert-reload.pem") - c.Assert(err, IsNil) - connOverrider = func(config *mysql.Config) { - config.TLSConfig = "skip-verify" - } - err = cli.runReloadTLS(c, connOverrider, false) - c.Assert(err, IsNil) - connOverrider = func(config *mysql.Config) { - config.TLSConfig = "client-certificate-reload" - } - err = cli.runTestTLSConnection(c, connOverrider) - c.Assert(err, NotNil) - c.Assert(util.IsTLSExpiredError(err), IsTrue, Commentf("real error is %+v", err)) - server.Close() -} - func TestClientWithCollation(t *testing.T) { t.Parallel() ts, cleanup := createTiDBTest(t) From 7f4524e4130811460d345febb53a3d2d9eb6e7d8 Mon Sep 17 00:00:00 2001 From: yedamo Date: Tue, 9 Nov 2021 23:00:53 +0800 Subject: [PATCH 41/55] server: remove tidbTestSerialSuite --- server/tidb_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 678054ed1c455..b3ef95d4c4e1b 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -127,10 +127,6 @@ func createTiDBTest(t *testing.T) (*tidbTest, func()) { return &tidbTest{base}, cleanup } -type tidbTestSerialSuite struct { - *tidbTestSuiteBase -} - type tidbTestTopSQLSuite struct { *tidbTestSuiteBase } @@ -150,7 +146,6 @@ func newTiDBTestSuiteBase() *tidbTestSuiteBase { } var _ = Suite(&tidbTestSuite{newTiDBTestSuiteBase()}) -var _ = SerialSuites(&tidbTestSerialSuite{newTiDBTestSuiteBase()}) var _ = SerialSuites(&tidbTestTopSQLSuite{newTiDBTestSuiteBase()}) func (ts *tidbTestSuite) SetUpSuite(c *C) { From 624c8bf63cfc307a94343bfdbc958ad29aba0637 Mon Sep 17 00:00:00 2001 From: yedamo Date: Tue, 9 Nov 2021 23:03:42 +0800 Subject: [PATCH 42/55] server: remove unused runTestTLSConnection/runReloadTLS --- server/server_test.go | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 6ed434bf704f1..5102728bc204a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2017,21 +2017,6 @@ func (cli *testingServerClient) runTestStmtCount(t *testing.T) { }) } -func (cli *testServerClient) runTestTLSConnection(t *C, overrider configOverrider) error { - dsn := cli.getDSN(overrider) - db, err := sql.Open("mysql", dsn) - t.Assert(err, IsNil) - defer func() { - err := db.Close() - t.Assert(err, IsNil) - }() - _, err = db.Exec("USE test") - if err != nil { - return errors.Annotate(err, "dsn:"+dsn) - } - return err -} - func (cli *testingServerClient) runTestTLSConnection(t *testing.T, overrider configOverrider) error { dsn := cli.getDSN(overrider) db, err := sql.Open("mysql", dsn) @@ -2047,21 +2032,6 @@ func (cli *testingServerClient) runTestTLSConnection(t *testing.T, overrider con return err } -func (cli *testServerClient) runReloadTLS(t *C, overrider configOverrider, errorNoRollback bool) error { - db, err := sql.Open("mysql", cli.getDSN(overrider)) - t.Assert(err, IsNil) - defer func() { - err := db.Close() - t.Assert(err, IsNil) - }() - sql := "alter instance reload tls" - if errorNoRollback { - sql += " no rollback on error" - } - _, err = db.Exec(sql) - return err -} - func (cli *testingServerClient) runReloadTLS(t *testing.T, overrider configOverrider, errorNoRollback bool) error { db, err := sql.Open("mysql", cli.getDSN(overrider)) require.NoError(t, err) From a0bca550156c9aba1197e1b5422211c25a6fa156 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 13 Nov 2021 23:51:46 +0800 Subject: [PATCH 43/55] server: migrate `TestClientErrors` --- server/server_test.go | 14 +++++++------- server/tidb_test.go | 6 ++++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 5102728bc204a..2b6a803ba19a2 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2188,8 +2188,8 @@ func (cli *testServerClient) runTestInitConnect(c *C) { // Client errors are only incremented when using the TiDB Server protocol, // and not internal SQL statements. Thus, this test is in the server-test suite. -func (cli *testServerClient) runTestInfoschemaClientErrors(t *C) { - cli.runTestsOnNewDB(t, nil, "clientErrors", func(dbt *DBTest) { +func (cli *testingServerClient) runTestInfoschemaClientErrors(t *testing.T) { + cli.runTestsOnNewDB(t, nil, "clientErrors", func(dbt *testkit.DBTestKit) { clientErrors := []struct { stmt string @@ -2221,7 +2221,7 @@ func (cli *testServerClient) runTestInfoschemaClientErrors(t *C) { for _, tbl := range sources { var errors, warnings int - rows := dbt.mustQuery("SELECT SUM(error_count), SUM(warning_count) FROM information_schema."+tbl+" WHERE error_number = ? GROUP BY error_number", test.errCode) + rows := dbt.MustQuery("SELECT SUM(error_count), SUM(warning_count) FROM information_schema."+tbl+" WHERE error_number = ? GROUP BY error_number", test.errCode) if rows.Next() { rows.Scan(&errors, &warnings) } @@ -2234,7 +2234,7 @@ func (cli *testServerClient) runTestInfoschemaClientErrors(t *C) { warnings++ } var err error - rows, err = dbt.db.Query(test.stmt) + rows, err = dbt.GetDB().Query(test.stmt) if err == nil { // make sure to read the result since the error/warnings are populated in the network send code. if rows.Next() { @@ -2244,13 +2244,13 @@ func (cli *testServerClient) runTestInfoschemaClientErrors(t *C) { rows.Close() } var newErrors, newWarnings int - rows = dbt.mustQuery("SELECT SUM(error_count), SUM(warning_count) FROM information_schema."+tbl+" WHERE error_number = ? GROUP BY error_number", test.errCode) + rows = dbt.MustQuery("SELECT SUM(error_count), SUM(warning_count) FROM information_schema."+tbl+" WHERE error_number = ? GROUP BY error_number", test.errCode) if rows.Next() { rows.Scan(&newErrors, &newWarnings) } rows.Close() - dbt.Check(newErrors, Equals, errors) - dbt.Check(newWarnings, Equals, warnings, Commentf("source=information_schema.%s code=%d statement=%s", tbl, test.errCode, test.stmt)) + require.Equal(t, errors, newErrors) + require.Equalf(t, warnings, newWarnings, "source=information_schema.%s code=%d statement=%s", tbl, test.errCode, test.stmt) } } diff --git a/server/tidb_test.go b/server/tidb_test.go index b3ef95d4c4e1b..dadd44311551d 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1149,8 +1149,10 @@ func TestFieldList(t *testing.T) { require.Equal(t, columnAsName, cols[0].Name) } -func (ts *tidbTestSuite) TestClientErrors(c *C) { - ts.runTestInfoschemaClientErrors(c) +func TestClientErrors(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + ts.runTestInfoschemaClientErrors(t) } func (ts *tidbTestSuite) TestInitConnect(c *C) { From c6c42877cac20f60c3cfbe3b24f91c5c67e92dd1 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 14 Nov 2021 00:01:01 +0800 Subject: [PATCH 44/55] server: migrate `TestInitConnect` --- server/server_test.go | 60 +++++++++++++++++++++---------------------- server/tidb_test.go | 6 +++-- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 2b6a803ba19a2..6ad50987a26c3 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2131,59 +2131,59 @@ func (cli *testServerClient) waitUntilServerOnline() { } } -func (cli *testServerClient) runTestInitConnect(c *C) { +func (cli *testingServerClient) runTestInitConnect(t *testing.T) { - cli.runTests(c, nil, func(dbt *DBTest) { - dbt.mustExec(`SET GLOBAL init_connect="insert into test.ts VALUES (NOW());SET @a=1;"`) - dbt.mustExec(`CREATE USER init_nonsuper`) - dbt.mustExec(`CREATE USER init_super`) - dbt.mustExec(`GRANT SELECT, INSERT, DROP ON test.* TO init_nonsuper`) - dbt.mustExec(`GRANT SELECT, INSERT, DROP, SUPER ON *.* TO init_super`) - dbt.mustExec(`CREATE TABLE ts (a TIMESTAMP)`) + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec(`SET GLOBAL init_connect="insert into test.ts VALUES (NOW());SET @a=1;"`) + dbt.MustExec(`CREATE USER init_nonsuper`) + dbt.MustExec(`CREATE USER init_super`) + dbt.MustExec(`GRANT SELECT, INSERT, DROP ON test.* TO init_nonsuper`) + dbt.MustExec(`GRANT SELECT, INSERT, DROP, SUPER ON *.* TO init_super`) + dbt.MustExec(`CREATE TABLE ts (a TIMESTAMP)`) }) // test init_nonsuper - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "init_nonsuper" - }, func(dbt *DBTest) { - rows := dbt.mustQuery(`SELECT @a`) - c.Assert(rows.Next(), IsTrue) + }, func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery(`SELECT @a`) + require.True(t, rows.Next()) var a int err := rows.Scan(&a) - c.Assert(err, IsNil) - dbt.Check(a, Equals, 1) - c.Assert(rows.Close(), IsNil) + require.NoError(t, err) + require.Equal(t, 1, a) + require.NoError(t, rows.Close()) }) // test init_super - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "init_super" - }, func(dbt *DBTest) { - rows := dbt.mustQuery(`SELECT IFNULL(@a,"")`) - c.Assert(rows.Next(), IsTrue) + }, func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery(`SELECT IFNULL(@a,"")`) + require.True(t, rows.Next()) var a string err := rows.Scan(&a) - c.Assert(err, IsNil) - dbt.Check(a, Equals, "") // null - c.Assert(rows.Close(), IsNil) + require.NoError(t, err) + require.Equal(t, "", a) + require.NoError(t, rows.Close()) // change the init-connect to invalid. - dbt.mustExec(`SET GLOBAL init_connect="invalidstring"`) + dbt.MustExec(`SET GLOBAL init_connect="invalidstring"`) }) // set global init_connect to empty to avoid fail other tests - defer cli.runTests(c, func(config *mysql.Config) { + defer cli.runTests(t, func(config *mysql.Config) { config.User = "init_super" - }, func(dbt *DBTest) { + }, func(dbt *testkit.DBTestKit) { // set init_connect to empty to avoid fail other tests - dbt.mustExec(`SET GLOBAL init_connect=""`) + dbt.MustExec(`SET GLOBAL init_connect=""`) }) db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { config.User = "init_nonsuper" })) - c.Assert(err, IsNil, Commentf("Error connecting")) // doesn't fail because of lazy loading - defer db.Close() // may already be closed - _, err = db.Exec("SELECT 1") // fails because of init sql - c.Assert(err, NotNil) + require.NoErrorf(t, err, "Error connecting") // doesn't fail because of lazy loading + defer db.Close() // may already be closed + _, err = db.Exec("SELECT 1") // fails because of init sql + require.Error(t, err) } // Client errors are only incremented when using the TiDB Server protocol, diff --git a/server/tidb_test.go b/server/tidb_test.go index dadd44311551d..f4e333a1f006d 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1155,8 +1155,10 @@ func TestClientErrors(t *testing.T) { ts.runTestInfoschemaClientErrors(t) } -func (ts *tidbTestSuite) TestInitConnect(c *C) { - ts.runTestInitConnect(c) +func TestInitConnect(t *testing.T) { + ts, cleanup := createTiDBTest(t) + defer cleanup() + ts.runTestInitConnect(t) } func (ts *tidbTestSuite) TestSumAvg(c *C) { From 3d102dc3746cd8c1661c6fb01ff47e5f8648f5c4 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 14 Nov 2021 00:05:25 +0800 Subject: [PATCH 45/55] server: migrate `TestSumAvg` --- server/server_test.go | 32 ++++++++++++++++---------------- server/tidb_test.go | 10 +++++++--- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 6ad50987a26c3..c15e745d36832 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2047,25 +2047,25 @@ func (cli *testingServerClient) runReloadTLS(t *testing.T, overrider configOverr return err } -func (cli *testServerClient) runTestSumAvg(c *C) { - cli.runTests(c, nil, func(dbt *DBTest) { - dbt.mustExec("create table sumavg (a int, b decimal, c double)") - dbt.mustExec("insert sumavg values (1, 1, 1)") - rows := dbt.mustQuery("select sum(a), sum(b), sum(c) from sumavg") - c.Assert(rows.Next(), IsTrue) +func (cli *testingServerClient) runTestSumAvg(t *testing.T) { + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table sumavg (a int, b decimal, c double)") + dbt.MustExec("insert sumavg values (1, 1, 1)") + rows := dbt.MustQuery("select sum(a), sum(b), sum(c) from sumavg") + require.True(t, rows.Next()) var outA, outB, outC float64 err := rows.Scan(&outA, &outB, &outC) - c.Assert(err, IsNil) - c.Assert(outA, Equals, 1.0) - c.Assert(outB, Equals, 1.0) - c.Assert(outC, Equals, 1.0) - rows = dbt.mustQuery("select avg(a), avg(b), avg(c) from sumavg") - c.Assert(rows.Next(), IsTrue) + require.NoError(t, err) + require.Equal(t, 1.0, outA) + require.Equal(t, 1.0, outB) + require.Equal(t, 1.0, outC) + rows = dbt.MustQuery("select avg(a), avg(b), avg(c) from sumavg") + require.True(t, rows.Next()) err = rows.Scan(&outA, &outB, &outC) - c.Assert(err, IsNil) - c.Assert(outA, Equals, 1.0) - c.Assert(outB, Equals, 1.0) - c.Assert(outC, Equals, 1.0) + require.NoError(t, err) + require.Equal(t, 1.0, outA) + require.Equal(t, 1.0, outB) + require.Equal(t, 1.0, outC) }) } diff --git a/server/tidb_test.go b/server/tidb_test.go index f4e333a1f006d..081dcb9500daf 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1150,20 +1150,24 @@ func TestFieldList(t *testing.T) { } func TestClientErrors(t *testing.T) { + t.Parallel() ts, cleanup := createTiDBTest(t) defer cleanup() ts.runTestInfoschemaClientErrors(t) } func TestInitConnect(t *testing.T) { + t.Parallel() ts, cleanup := createTiDBTest(t) defer cleanup() ts.runTestInitConnect(t) } -func (ts *tidbTestSuite) TestSumAvg(c *C) { - c.Parallel() - ts.runTestSumAvg(c) +func TestSumAvg(t *testing.T) { + t.Parallel() + ts, cleanup := createTiDBTest(t) + defer cleanup() + ts.runTestSumAvg(t) } func TestNullFlag(t *testing.T) { From ed2951ecd0c91b74bdb6d55e52b43f686ad199c0 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 14 Nov 2021 15:15:26 +0800 Subject: [PATCH 46/55] server: migrate `TestOnlySocket` --- server/tidb_test.go | 116 +++++++++++++++++++++++--------------------- 1 file changed, 60 insertions(+), 56 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 081dcb9500daf..d29690c34ed7d 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -691,55 +691,59 @@ func (ts *tidbTestSuite) TestSocketAndIp(c *C) { } // TestOnlySocket for server configuration without network interface for mysql clients -func (ts *tidbTestSuite) TestOnlySocket(c *C) { +func TestOnlySocket(t *testing.T) { osTempDir := os.TempDir() tempDir, err := os.MkdirTemp(osTempDir, "tidb-test.*.socket") - c.Assert(err, IsNil) + require.NoError(t, err) socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK defer os.RemoveAll(tempDir) - cli := newTestServerClient() + + cli := newTestingServerClient() cfg := newTestConfig() cfg.Socket = socketFile cfg.Host = "" // No network interface listening for mysql traffic cfg.Status.ReportStatus = false + ts, cleanup := createTiDBTest(t) + defer cleanup() + server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) + require.NoError(t, err) go func() { err := server.Run() - c.Assert(err, IsNil) + require.NoError(t, err) }() time.Sleep(time.Millisecond * 100) defer server.Close() - c.Assert(server.listener, IsNil) - c.Assert(server.socket, NotNil) + require.Nil(t, server.listener) + require.NotNil(t, server.socket) // Test with Socket connection + Setup user1@% for all host access defer func() { - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "root" config.Net = "unix" config.Addr = socketFile }, - func(dbt *DBTest) { - dbt.mustQuery("DROP USER IF EXISTS 'user1'@'%'") - dbt.mustQuery("DROP USER IF EXISTS 'user1'@'localhost'") - dbt.mustQuery("DROP USER IF EXISTS 'user1'@'127.0.0.1'") + func(dbt *testkit.DBTestKit) { + dbt.MustQuery("DROP USER IF EXISTS 'user1'@'%'") + dbt.MustQuery("DROP USER IF EXISTS 'user1'@'localhost'") + dbt.MustQuery("DROP USER IF EXISTS 'user1'@'127.0.0.1'") }) }() - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "root" config.Net = "unix" config.Addr = socketFile config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "root@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") - dbt.mustQuery("CREATE USER user1@'%'") - dbt.mustQuery("GRANT SELECT ON test.* TO user1@'%'") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "root@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + dbt.MustQuery("CREATE USER user1@'%'") + dbt.MustQuery("GRANT SELECT ON test.* TO user1@'%'") }) // Test with Network interface connection with all hosts, should fail since server not configured db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { @@ -747,86 +751,86 @@ func (ts *tidbTestSuite) TestOnlySocket(c *C) { config.DBName = "test" config.Addr = "127.0.0.1" })) - c.Assert(err, IsNil, Commentf("Connect succeeded when not configured!?!")) + require.NoErrorf(t, err, "Connect succeeded when not configured!?!") defer db.Close() db, err = sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { config.User = "user1" config.DBName = "test" config.Addr = "127.0.0.1" })) - c.Assert(err, IsNil, Commentf("Connect succeeded when not configured!?!")) + require.NoErrorf(t, err, "Connect succeeded when not configured!?!") defer db.Close() // Test with unix domain socket file connection with all hosts - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.Net = "unix" config.Addr = socketFile config.User = "user1" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "user1@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "user1@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") }) // Setup user1@127.0.0.1 for loop back network interface access - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.Net = "unix" config.Addr = socketFile config.User = "root" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) - cli.checkRows(c, rows, "root@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") - dbt.mustQuery("CREATE USER user1@127.0.0.1") - dbt.mustQuery("GRANT SELECT,INSERT ON test.* TO user1@'127.0.0.1'") + cli.checkRows(t, rows, "root@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + dbt.MustQuery("CREATE USER user1@127.0.0.1") + dbt.MustQuery("GRANT SELECT,INSERT ON test.* TO user1@'127.0.0.1'") }) // Test with unix domain socket file connection with all hosts - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.Net = "unix" config.Addr = socketFile config.User = "user1" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "user1@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "user1@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") }) // Setup user1@localhost for socket (and if MySQL compatible; loop back network interface access) - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.Net = "unix" config.Addr = socketFile config.User = "root" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "root@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") - dbt.mustQuery("CREATE USER user1@localhost") - dbt.mustQuery("GRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO user1@localhost") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "root@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + dbt.MustQuery("CREATE USER user1@localhost") + dbt.MustQuery("GRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO user1@localhost") }) // Test with unix domain socket file connection with all hosts - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.Net = "unix" config.Addr = socketFile config.User = "user1" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "user1@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT USAGE ON *.* TO 'user1'@'localhost'\nGRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO 'user1'@'localhost'") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "user1@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'localhost'\nGRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO 'user1'@'localhost'") }) } From 6bdeb730da4f7ad2f44e726c512e40d7765adc55 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sun, 14 Nov 2021 15:31:01 +0800 Subject: [PATCH 47/55] server: migrate `TestSocketAndIp` --- server/server_test.go | 45 ------------- server/tidb_test.go | 144 ++++++++++++++++++++++-------------------- 2 files changed, 74 insertions(+), 115 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index c15e745d36832..b8f928f4e0c37 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -122,21 +122,6 @@ func (cli *testServerClient) getDSN(overriders ...configOverrider) string { return config.FormatDSN() } -// runTests runs tests using the default database `test`. -func (cli *testServerClient) runTests(c *C, overrider configOverrider, tests ...func(dbt *DBTest)) { - db, err := sql.Open("mysql", cli.getDSN(overrider)) - c.Assert(err, IsNil, Commentf("Error connecting")) - defer func() { - err := db.Close() - c.Assert(err, IsNil) - }() - - dbt := &DBTest{c, db} - for _, test := range tests { - test(dbt) - } -} - // runTests runs tests using the default database `test`. func (cli *testingServerClient) runTests(t *testing.T, overrider configOverrider, tests ...func(dbt *testkit.DBTestKit)) { db, err := sql.Open("mysql", cli.getDSN(overrider)) @@ -951,36 +936,6 @@ func (cli *testingServerClient) runTestLoadDataForListColumnPartition2(t *testin }) } -func (cli *testServerClient) checkRows(c *C, rows *sql.Rows, expectedRows ...string) { - buf := bytes.NewBuffer(nil) - result := make([]string, 0, 2) - for rows.Next() { - cols, err := rows.Columns() - c.Assert(err, IsNil) - rawResult := make([][]byte, len(cols)) - dest := make([]interface{}, len(cols)) - for i := range rawResult { - dest[i] = &rawResult[i] - } - - err = rows.Scan(dest...) - c.Assert(err, IsNil) - buf.Reset() - for i, raw := range rawResult { - if i > 0 { - buf.WriteString(" ") - } - if raw == nil { - buf.WriteString("") - } else { - buf.WriteString(string(raw)) - } - } - result = append(result, buf.String()) - } - c.Assert(strings.Join(result, "\n"), Equals, strings.Join(expectedRows, "\n")) -} - func (cli *testingServerClient) checkRows(t *testing.T, rows *sql.Rows, expectedRows ...string) { buf := bytes.NewBuffer(nil) result := make([]string, 0, 2) diff --git a/server/tidb_test.go b/server/tidb_test.go index d29690c34ed7d..e426ce8d28963 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -533,24 +533,28 @@ func (ts *tidbTestSuite) TestSocket(c *C) { } -func (ts *tidbTestSuite) TestSocketAndIp(c *C) { +func TestSocketAndIp(t *testing.T) { osTempDir := os.TempDir() tempDir, err := os.MkdirTemp(osTempDir, "tidb-test.*.socket") - c.Assert(err, IsNil) + require.NoError(t, err) socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK defer os.RemoveAll(tempDir) - cli := newTestServerClient() + + cli := newTestingServerClient() cfg := newTestConfig() cfg.Socket = socketFile cfg.Port = cli.port cfg.Status.ReportStatus = false + ts, cleanup := createTiDBTest(t) + defer cleanup() + server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) + require.NoError(t, err) cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run() - c.Assert(err, IsNil) + require.NoError(t, err) }() time.Sleep(time.Millisecond * 100) defer server.Close() @@ -558,134 +562,134 @@ func (ts *tidbTestSuite) TestSocketAndIp(c *C) { // Test with Socket connection + Setup user1@% for all host access cli.port = getPortFromTCPAddr(server.listener.Addr()) defer func() { - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "root" }, - func(dbt *DBTest) { - dbt.mustQuery("DROP USER IF EXISTS 'user1'@'%'") - dbt.mustQuery("DROP USER IF EXISTS 'user1'@'localhost'") - dbt.mustQuery("DROP USER IF EXISTS 'user1'@'127.0.0.1'") + func(dbt *testkit.DBTestKit) { + dbt.MustQuery("DROP USER IF EXISTS 'user1'@'%'") + dbt.MustQuery("DROP USER IF EXISTS 'user1'@'localhost'") + dbt.MustQuery("DROP USER IF EXISTS 'user1'@'127.0.0.1'") }) }() - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "root" config.Net = "unix" config.Addr = socketFile config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "root@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") - dbt.mustQuery("CREATE USER user1@'%'") - dbt.mustQuery("GRANT SELECT ON test.* TO user1@'%'") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "root@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + dbt.MustQuery("CREATE USER user1@'%'") + dbt.MustQuery("GRANT SELECT ON test.* TO user1@'%'") }) // Test with Network interface connection with all hosts - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "user1" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) - cli.checkRows(c, rows, "user1@127.0.0.1") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") + cli.checkRows(t, rows, "user1@127.0.0.1") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") }) // Test with unix domain socket file connection with all hosts - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.Net = "unix" config.Addr = socketFile config.User = "user1" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "user1@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "user1@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") }) // Setup user1@127.0.0.1 for loop back network interface access - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "root" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) - cli.checkRows(c, rows, "root@127.0.0.1") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") - dbt.mustQuery("CREATE USER user1@127.0.0.1") - dbt.mustQuery("GRANT SELECT,INSERT ON test.* TO user1@'127.0.0.1'") + cli.checkRows(t, rows, "root@127.0.0.1") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + dbt.MustQuery("CREATE USER user1@127.0.0.1") + dbt.MustQuery("GRANT SELECT,INSERT ON test.* TO user1@'127.0.0.1'") }) // Test with Network interface connection with all hosts - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "user1" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) - cli.checkRows(c, rows, "user1@127.0.0.1") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT USAGE ON *.* TO 'user1'@'127.0.0.1'\nGRANT SELECT,INSERT ON test.* TO 'user1'@'127.0.0.1'") + cli.checkRows(t, rows, "user1@127.0.0.1") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'127.0.0.1'\nGRANT SELECT,INSERT ON test.* TO 'user1'@'127.0.0.1'") }) // Test with unix domain socket file connection with all hosts - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.Net = "unix" config.Addr = socketFile config.User = "user1" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "user1@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "user1@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON test.* TO 'user1'@'%'") }) // Setup user1@localhost for socket (and if MySQL compatible; loop back network interface access) - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.Net = "unix" config.Addr = socketFile config.User = "root" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "root@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") - dbt.mustQuery("CREATE USER user1@localhost") - dbt.mustQuery("GRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO user1@localhost") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "root@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + dbt.MustQuery("CREATE USER user1@localhost") + dbt.MustQuery("GRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO user1@localhost") }) // Test with Network interface connection with all hosts - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.User = "user1" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) - cli.checkRows(c, rows, "user1@127.0.0.1") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT USAGE ON *.* TO 'user1'@'127.0.0.1'\nGRANT SELECT,INSERT ON test.* TO 'user1'@'127.0.0.1'") + cli.checkRows(t, rows, "user1@127.0.0.1") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'127.0.0.1'\nGRANT SELECT,INSERT ON test.* TO 'user1'@'127.0.0.1'") }) // Test with unix domain socket file connection with all hosts - cli.runTests(c, func(config *mysql.Config) { + cli.runTests(t, func(config *mysql.Config) { config.Net = "unix" config.Addr = socketFile config.User = "user1" config.DBName = "test" }, - func(dbt *DBTest) { - rows := dbt.mustQuery("select user()") - cli.checkRows(c, rows, "user1@localhost") - rows = dbt.mustQuery("show grants") - cli.checkRows(c, rows, "GRANT USAGE ON *.* TO 'user1'@'localhost'\nGRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO 'user1'@'localhost'") + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.checkRows(t, rows, "user1@localhost") + rows = dbt.MustQuery("show grants") + cli.checkRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'localhost'\nGRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO 'user1'@'localhost'") }) } From fa9acdf60b39612586a741addeb40e59a766d966 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 20 Nov 2021 16:33:56 +0800 Subject: [PATCH 48/55] server: migrate `TestSocket` --- server/main_test.go | 3 +++ server/tidb_test.go | 29 ++++++++--------------------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/server/main_test.go b/server/main_test.go index 4c3969da1f506..e4e7e9bf8d1e4 100644 --- a/server/main_test.go +++ b/server/main_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/util/testbridge" "github.com/tikv/client-go/v2/tikv" @@ -36,6 +37,8 @@ func TestMain(m *testing.M) { tikv.EnableFailpoints() + metrics.RegisterMetrics() + // sanity check: the global config should not be changed by other pkg init function. // see also https://github.com/pingcap/tidb/issues/22162 defaultConfig := config.NewConfig() diff --git a/server/tidb_test.go b/server/tidb_test.go index e426ce8d28963..97d5bdcf79355 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -42,7 +42,6 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/parser" tmysql "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/session" @@ -112,10 +111,6 @@ func createTiDBTestBase(t *testing.T) (*tidbTestBase, func()) { return ts, cleanup } -type tidbTestSuite struct { - *tidbTestSuiteBase -} - type tidbTest struct { *tidbTestBase } @@ -145,18 +140,8 @@ func newTiDBTestSuiteBase() *tidbTestSuiteBase { } } -var _ = Suite(&tidbTestSuite{newTiDBTestSuiteBase()}) var _ = SerialSuites(&tidbTestTopSQLSuite{newTiDBTestSuiteBase()}) -func (ts *tidbTestSuite) SetUpSuite(c *C) { - metrics.RegisterMetrics() - ts.tidbTestSuiteBase.SetUpSuite(c) -} - -func (ts *tidbTestSuite) TearDownSuite(c *C) { - ts.tidbTestSuiteBase.TearDownSuite(c) -} - func (ts *tidbTestTopSQLSuite) SetUpSuite(c *C) { ts.tidbTestSuiteBase.SetUpSuite(c) @@ -504,7 +489,7 @@ func TestSocketForwarding(t *testing.T) { }, "SocketRegression") } -func (ts *tidbTestSuite) TestSocket(c *C) { +func TestSocket(t *testing.T) { cfg := newTestConfig() cfg.Socket = "/tmp/tidbtest.sock" cfg.Port = 0 @@ -512,25 +497,27 @@ func (ts *tidbTestSuite) TestSocket(c *C) { cfg.Host = "" cfg.Status.ReportStatus = false + ts, cleanup := createTiDBTest(t) + defer cleanup() + server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) + require.NoError(t, err) go func() { err := server.Run() - c.Assert(err, IsNil) + require.NoError(t, err) }() time.Sleep(time.Millisecond * 100) defer server.Close() // a fake server client, config is override, just used to run tests - cli := newTestServerClient() - cli.runTestRegression(c, func(config *mysql.Config) { + cli := newTestingServerClient() + cli.runTestRegression(t, func(config *mysql.Config) { config.User = "root" config.Net = "unix" config.Addr = "/tmp/tidbtest.sock" config.DBName = "test" config.Params = map[string]string{"sql_mode": "STRICT_ALL_TABLES"} }, "SocketRegression") - } func TestSocketAndIp(t *testing.T) { From ac26eb824b39e33db76b8e067f80b1e7d3435c80 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 20 Nov 2021 16:47:10 +0800 Subject: [PATCH 49/55] server: make tests run parallel --- server/tidb_test.go | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 97d5bdcf79355..800743f6b6668 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -460,12 +460,18 @@ func TestMultiStatements(t *testing.T) { func TestSocketForwarding(t *testing.T) { t.Parallel() + osTempDir := os.TempDir() + tempDir, err := os.MkdirTemp(osTempDir, "tidb-test.*.socket") + require.NoError(t, err) + socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK + defer os.RemoveAll(tempDir) + ts, cleanup := createTiDBTest(t) defer cleanup() cli := newTestingServerClient() cfg := newTestConfig() - cfg.Socket = "/tmp/tidbtest.sock" + cfg.Socket = socketFile cfg.Port = cli.port os.Remove(cfg.Socket) cfg.Status.ReportStatus = false @@ -483,15 +489,22 @@ func TestSocketForwarding(t *testing.T) { cli.runTestRegression(t, func(config *mysql.Config) { config.User = "root" config.Net = "unix" - config.Addr = "/tmp/tidbtest.sock" + config.Addr = socketFile config.DBName = "test" config.Params = map[string]string{"sql_mode": "'STRICT_ALL_TABLES'"} }, "SocketRegression") } func TestSocket(t *testing.T) { + t.Parallel() + osTempDir := os.TempDir() + tempDir, err := os.MkdirTemp(osTempDir, "tidb-test.*.socket") + require.NoError(t, err) + socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK + defer os.RemoveAll(tempDir) + cfg := newTestConfig() - cfg.Socket = "/tmp/tidbtest.sock" + cfg.Socket = socketFile cfg.Port = 0 os.Remove(cfg.Socket) cfg.Host = "" @@ -514,13 +527,14 @@ func TestSocket(t *testing.T) { cli.runTestRegression(t, func(config *mysql.Config) { config.User = "root" config.Net = "unix" - config.Addr = "/tmp/tidbtest.sock" + config.Addr = socketFile config.DBName = "test" config.Params = map[string]string{"sql_mode": "STRICT_ALL_TABLES"} }, "SocketRegression") } func TestSocketAndIp(t *testing.T) { + t.Parallel() osTempDir := os.TempDir() tempDir, err := os.MkdirTemp(osTempDir, "tidb-test.*.socket") require.NoError(t, err) @@ -683,6 +697,7 @@ func TestSocketAndIp(t *testing.T) { // TestOnlySocket for server configuration without network interface for mysql clients func TestOnlySocket(t *testing.T) { + t.Parallel() osTempDir := os.TempDir() tempDir, err := os.MkdirTemp(osTempDir, "tidb-test.*.socket") require.NoError(t, err) From 495d440a822542fe302118fa1e54e3022c943487 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 20 Nov 2021 17:09:36 +0800 Subject: [PATCH 50/55] server: migrate `TestTopSQLCPUProfile` --- server/tidb_test.go | 118 ++++++++++++++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 36 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index 800743f6b6668..f8e1dcef9806d 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -117,11 +117,36 @@ type tidbTest struct { func createTiDBTest(t *testing.T) (*tidbTest, func()) { base, cleanup := createTiDBTestBase(t) - // TODO: register metrics - // metrics.RegisterMetrics() return &tidbTest{base}, cleanup } +type tidbTestTopSQL struct { + *tidbTestBase +} + +func createTiDBTestTopSQL(t *testing.T) (*tidbTestTopSQL, func()) { + base, cleanup := createTiDBTestBase(t) + + ts := &tidbTestTopSQL{base} + + // Initialize global variable for top-sql test. + db, err := sql.Open("mysql", ts.getDSN()) + require.NoErrorf(t, err, "Error connecting") + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("set @@global.tidb_top_sql_precision_seconds=1;") + dbt.MustExec("set @@global.tidb_top_sql_report_interval_seconds=2;") + dbt.MustExec("set @@global.tidb_top_sql_max_statement_count=5;") + + tracecpu.GlobalSQLCPUProfiler.Run() + + return ts, cleanup +} + type tidbTestTopSQLSuite struct { *tidbTestSuiteBase } @@ -1345,39 +1370,42 @@ type collectorWrapper struct { reporter.TopSQLReporter } -func (ts *tidbTestTopSQLSuite) TestTopSQLCPUProfile(c *C) { +func TestTopSQLCPUProfile(t *testing.T) { + ts, cleanup := createTiDBTestTopSQL(t) + defer cleanup() + db, err := sql.Open("mysql", ts.getDSN()) - c.Assert(err, IsNil, Commentf("Error connecting")) + require.NoErrorf(t, err, "Error connecting") defer func() { err := db.Close() - c.Assert(err, IsNil) + require.NoError(t, err) }() - c.Assert(failpoint.Enable("github.com/pingcap/tidb/domain/skipLoadSysVarCacheLoop", `return(true)`), IsNil) - c.Assert(failpoint.Enable("github.com/pingcap/tidb/util/topsql/mockHighLoadForEachSQL", `return(true)`), IsNil) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/skipLoadSysVarCacheLoop", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/topsql/mockHighLoadForEachSQL", `return(true)`)) defer func() { err = failpoint.Disable("github.com/pingcap/tidb/domain/skipLoadSysVarCacheLoop") - c.Assert(err, IsNil) + require.NoError(t, err) err = failpoint.Disable("github.com/pingcap/tidb/util/topsql/mockHighLoadForEachSQL") - c.Assert(err, IsNil) + require.NoError(t, err) }() collector := mockTopSQLTraceCPU.NewTopSQLCollector() tracecpu.GlobalSQLCPUProfiler.SetCollector(&collectorWrapper{collector}) - dbt := &DBTest{c, db} - dbt.mustExec("drop database if exists topsql") - dbt.mustExec("create database topsql") - dbt.mustExec("use topsql;") - dbt.mustExec("create table t (a int auto_increment, b int, unique index idx(a));") - dbt.mustExec("create table t1 (a int auto_increment, b int, unique index idx(a));") - dbt.mustExec("create table t2 (a int auto_increment, b int, unique index idx(a));") - dbt.mustExec("set @@global.tidb_enable_top_sql='On';") + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("drop database if exists topsql") + dbt.MustExec("create database topsql") + dbt.MustExec("use topsql;") + dbt.MustExec("create table t (a int auto_increment, b int, unique index idx(a));") + dbt.MustExec("create table t1 (a int auto_increment, b int, unique index idx(a));") + dbt.MustExec("create table t2 (a int auto_increment, b int, unique index idx(a));") + dbt.MustExec("set @@global.tidb_enable_top_sql='On';") config.UpdateGlobal(func(conf *config.Config) { conf.TopSQL.ReceiverAddress = "127.0.0.1:4001" }) - dbt.mustExec("set @@global.tidb_top_sql_precision_seconds=1;") - dbt.mustExec("set @@global.tidb_txn_mode = 'pessimistic'") + dbt.MustExec("set @@global.tidb_top_sql_precision_seconds=1;") + dbt.MustExec("set @@global.tidb_txn_mode = 'pessimistic'") // Test case 1: DML query: insert/update/replace/delete/select cases1 := []struct { @@ -1401,10 +1429,10 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLCPUProfile(c *C) { ctx, cancel := context.WithCancel(context.Background()) cases1[i].cancel = cancel sqlStr := ca.sql - go ts.loopExec(ctx, c, func(db *sql.DB) { - dbt := &DBTest{c, db} + go ts.loopExec(ctx, t, func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) if strings.HasPrefix(sqlStr, "select") { - rows := dbt.mustQuery(sqlStr) + rows := dbt.MustQuery(sqlStr) for rows.Next() { } } else { @@ -1417,25 +1445,24 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLCPUProfile(c *C) { timeoutCtx, cancel := context.WithTimeout(context.Background(), time.Second*20) defer cancel() checkFn := func(sql, planRegexp string) { - c.Assert(timeoutCtx.Err(), IsNil) - commentf := Commentf("sql: %v", sql) + require.NoError(t, timeoutCtx.Err()) stats := collector.GetSQLStatsBySQLWithRetry(sql, len(planRegexp) > 0) // since 1 sql may has many plan, check `len(stats) > 0` instead of `len(stats) == 1`. - c.Assert(len(stats) > 0, IsTrue, commentf) + require.Greaterf(t, len(stats), 0, "sql: %v", sql) for _, s := range stats { sqlStr := collector.GetSQL(s.SQLDigest) encodedPlan := collector.GetPlan(s.PlanDigest) // Normalize the user SQL before check. normalizedSQL := parser.Normalize(sql) - c.Assert(sqlStr, Equals, normalizedSQL, commentf) + require.Equalf(t, normalizedSQL, sqlStr, "sql: %v", sql) // decode plan before check. normalizedPlan, err := plancodec.DecodeNormalizedPlan(encodedPlan) - c.Assert(err, IsNil) + require.NoError(t, err) // remove '\n' '\t' before do regexp match. normalizedPlan = strings.Replace(normalizedPlan, "\n", " ", -1) normalizedPlan = strings.Replace(normalizedPlan, "\t", " ", -1) - c.Assert(normalizedPlan, Matches, planRegexp, commentf) + require.Regexpf(t, planRegexp, normalizedPlan, "sql: %v", sql) } } // Wait the top sql collector to collect profile data. @@ -1470,14 +1497,14 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLCPUProfile(c *C) { cases2[i].cancel = cancel prepare, args := ca.prepare, ca.args var stmt *sql.Stmt - go ts.loopExec(ctx, c, func(db *sql.DB) { + go ts.loopExec(ctx, t, func(db *sql.DB) { if stmt == nil { stmt, err = db.Prepare(prepare) - c.Assert(err, IsNil) + require.NoError(t, err) } if strings.HasPrefix(prepare, "select") { rows, err := stmt.Query(args...) - c.Assert(err, IsNil) + require.NoError(t, err) for rows.Next() { } } else { @@ -1517,17 +1544,17 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLCPUProfile(c *C) { cases3[i].cancel = cancel prepare, args := ca.prepare, ca.args doPrepare := true - go ts.loopExec(ctx, c, func(db *sql.DB) { + go ts.loopExec(ctx, t, func(db *sql.DB) { if doPrepare { doPrepare = false _, err := db.Exec(fmt.Sprintf("prepare stmt from '%v'", prepare)) - c.Assert(err, IsNil) + require.NoError(t, err) } sqlBuf := bytes.NewBuffer(nil) sqlBuf.WriteString("execute stmt ") for i := range args { _, err = db.Exec(fmt.Sprintf("set @%c=%v", 'a'+i, args[i])) - c.Assert(err, IsNil) + require.NoError(t, err) if i == 0 { sqlBuf.WriteString("using ") } else { @@ -1538,7 +1565,7 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLCPUProfile(c *C) { } if strings.HasPrefix(prepare, "select") { rows, err := db.Query(sqlBuf.String()) - c.Assert(err, IsNil, Commentf("%v", sqlBuf.String())) + require.NoErrorf(t, err, "%v", sqlBuf.String()) for rows.Next() { } } else { @@ -1559,7 +1586,7 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLCPUProfile(c *C) { // Test case 4: transaction commit ctx4, cancel4 := context.WithCancel(context.Background()) defer cancel4() - go ts.loopExec(ctx4, c, func(db *sql.DB) { + go ts.loopExec(ctx4, t, func(db *sql.DB) { db.Exec("begin") db.Exec("insert into t () values (),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),()") db.Exec("commit") @@ -1738,3 +1765,22 @@ func (ts *tidbTestTopSQLSuite) loopExec(ctx context.Context, c *C, fn func(db *s fn(db) } } + +func (ts *tidbTestTopSQL) loopExec(ctx context.Context, t *testing.T, fn func(db *sql.DB)) { + db, err := sql.Open("mysql", ts.getDSN()) + require.NoError(t, err, "Error connecting") + defer func() { + err := db.Close() + require.NoError(t, err) + }() + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("use topsql;") + for { + select { + case <-ctx.Done(): + return + default: + } + fn(db) + } +} From b94de583ba05afd1d378352c2e70cef40a1f7a3f Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 20 Nov 2021 17:23:02 +0800 Subject: [PATCH 51/55] server: migrate `TestTopSQLAgent` --- server/tidb_test.go | 69 +++++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index f8e1dcef9806d..f048798438d49 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1595,40 +1595,43 @@ func TestTopSQLCPUProfile(t *testing.T) { checkFn("commit", "") } -func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { - c.Skip("unstable, skip it and fix it before 20210702") +func TestTopSQLAgent(t *testing.T) { + t.Skip("unstable, skip it and fix it before 20210702") + + ts, cleanup := createTiDBTestTopSQL(t) + defer cleanup() db, err := sql.Open("mysql", ts.getDSN()) - c.Assert(err, IsNil, Commentf("Error connecting")) + require.NoError(t, err, "Error connecting") defer func() { err := db.Close() - c.Assert(err, IsNil) + require.NoError(t, err) }() agentServer, err := mockTopSQLReporter.StartMockAgentServer() - c.Assert(err, IsNil) + require.NoError(t, err) defer func() { agentServer.Stop() }() - c.Assert(failpoint.Enable("github.com/pingcap/tidb/util/topsql/reporter/resetTimeoutForTest", `return(true)`), IsNil) - c.Assert(failpoint.Enable("github.com/pingcap/tidb/domain/skipLoadSysVarCacheLoop", `return(true)`), IsNil) - c.Assert(failpoint.Enable("github.com/pingcap/tidb/util/topsql/mockHighLoadForEachSQL", `return(true)`), IsNil) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/topsql/reporter/resetTimeoutForTest", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/skipLoadSysVarCacheLoop", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/topsql/mockHighLoadForEachSQL", `return(true)`)) defer func() { err := failpoint.Disable("github.com/pingcap/tidb/util/topsql/reporter/resetTimeoutForTest") - c.Assert(err, IsNil) + require.NoError(t, err) err = failpoint.Disable("github.com/pingcap/tidb/domain/skipLoadSysVarCacheLoop") - c.Assert(err, IsNil) + require.NoError(t, err) err = failpoint.Disable("github.com/pingcap/tidb/util/topsql/mockHighLoadForEachSQL") - c.Assert(err, IsNil) + require.NoError(t, err) }() - dbt := &DBTest{c, db} - dbt.mustExec("drop database if exists topsql") - dbt.mustExec("create database topsql") - dbt.mustExec("use topsql;") + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("drop database if exists topsql") + dbt.MustExec("create database topsql") + dbt.MustExec("use topsql;") for i := 0; i < 20; i++ { - dbt.mustExec(fmt.Sprintf("create table t%v (a int auto_increment, b int, unique index idx(a));", i)) + dbt.MustExec(fmt.Sprintf("create table t%v (a int auto_increment, b int, unique index idx(a));", i)) for j := 0; j < 100; j++ { - dbt.mustExec(fmt.Sprintf("insert into t%v (b) values (%v);", i, j)) + dbt.MustExec(fmt.Sprintf("insert into t%v (b) values (%v);", i, j)) } } setTopSQLReceiverAddress := func(addr string) { @@ -1636,11 +1639,11 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { conf.TopSQL.ReceiverAddress = addr }) } - dbt.mustExec("set @@global.tidb_enable_top_sql='On';") + dbt.MustExec("set @@global.tidb_enable_top_sql='On';") setTopSQLReceiverAddress("") - dbt.mustExec("set @@global.tidb_top_sql_precision_seconds=1;") - dbt.mustExec("set @@global.tidb_top_sql_report_interval_seconds=2;") - dbt.mustExec("set @@global.tidb_top_sql_max_statement_count=5;") + dbt.MustExec("set @@global.tidb_top_sql_precision_seconds=1;") + dbt.MustExec("set @@global.tidb_top_sql_report_interval_seconds=2;") + dbt.MustExec("set @@global.tidb_top_sql_max_statement_count=5;") r := reporter.NewRemoteTopSQLReporter(reporter.NewGRPCReportClient(plancodec.DecodeNormalizedPlan)) tracecpu.GlobalSQLCPUProfiler.SetCollector(&collectorWrapper{r}) @@ -1648,28 +1651,28 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { // TODO: change to ensure that the right sql statements are reported, not just counts checkFn := func(n int) { records := agentServer.GetLatestRecords() - c.Assert(len(records), Equals, n) + require.Len(t, records, n) for _, r := range records { sqlMeta, exist := agentServer.GetSQLMetaByDigestBlocking(r.SqlDigest, time.Second) - c.Assert(exist, IsTrue) - c.Check(sqlMeta.NormalizedSql, Matches, "select.*from.*join.*") + require.True(t, exist) + require.Regexp(t, "select.*from.*join.*", sqlMeta.NormalizedSql) if len(r.PlanDigest) == 0 { continue } plan, exist := agentServer.GetPlanMetaByDigestBlocking(r.PlanDigest, time.Second) - c.Assert(exist, IsTrue) + require.True(t, exist) plan = strings.Replace(plan, "\n", " ", -1) plan = strings.Replace(plan, "\t", " ", -1) - c.Assert(plan, Matches, ".*Join.*Select.*") + require.Regexp(t, ".*Join.*Select.*", plan) } } runWorkload := func(start, end int) context.CancelFunc { ctx, cancel := context.WithCancel(context.Background()) for i := start; i < end; i++ { query := fmt.Sprintf("select /*+ HASH_JOIN(ta, tb) */ * from t%[1]v ta join t%[1]v tb on ta.a=tb.a where ta.b is not null;", i) - go ts.loopExec(ctx, c, func(db *sql.DB) { - dbt := &DBTest{c, db} - rows := dbt.mustQuery(query) + go ts.loopExec(ctx, t, func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + rows := dbt.MustQuery(query) for rows.Next() { } }) @@ -1684,12 +1687,12 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { agentServer.WaitCollectCnt(1, time.Second*4) checkFn(0) // Test after set agent address and the evict take effect. - dbt.mustExec("set @@global.tidb_top_sql_max_statement_count=5;") + dbt.MustExec("set @@global.tidb_top_sql_max_statement_count=5;") setTopSQLReceiverAddress(agentServer.Address()) agentServer.WaitCollectCnt(1, time.Second*4) checkFn(5) // Test with wrong agent address, the agent server can't receive any record. - dbt.mustExec("set @@global.tidb_top_sql_max_statement_count=8;") + dbt.MustExec("set @@global.tidb_top_sql_max_statement_count=8;") setTopSQLReceiverAddress("127.0.0.1:65530") agentServer.WaitCollectCnt(1, time.Second*4) @@ -1703,7 +1706,7 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { // case 2: agent hangs for a while cancel2 := runWorkload(0, 10) // empty agent address, should not collect records - dbt.mustExec("set @@global.tidb_top_sql_max_statement_count=5;") + dbt.MustExec("set @@global.tidb_top_sql_max_statement_count=5;") setTopSQLReceiverAddress("") agentServer.WaitCollectCnt(1, time.Second*4) checkFn(0) @@ -1739,7 +1742,7 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { agentServer.Stop() // agent server restart agentServer, err = mockTopSQLReporter.StartMockAgentServer() - c.Assert(err, IsNil) + require.NoError(t, err) setTopSQLReceiverAddress(agentServer.Address()) // check result agentServer.WaitCollectCnt(2, time.Second*8) From a922edaff083bf233dde4f6fb1c7ebca513b33e5 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 20 Nov 2021 17:31:28 +0800 Subject: [PATCH 52/55] server: cleanup tidb_test.go --- server/tidb_test.go | 104 -------------------------------------------- 1 file changed, 104 deletions(-) diff --git a/server/tidb_test.go b/server/tidb_test.go index f048798438d49..8ade51be762c0 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -36,7 +36,6 @@ import ( "time" "github.com/go-sql-driver/mysql" - . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/config" @@ -147,90 +146,6 @@ func createTiDBTestTopSQL(t *testing.T) (*tidbTestTopSQL, func()) { return ts, cleanup } -type tidbTestTopSQLSuite struct { - *tidbTestSuiteBase -} - -type tidbTestSuiteBase struct { - *testServerClient - tidbdrv *TiDBDriver - server *Server - domain *domain.Domain - store kv.Storage -} - -func newTiDBTestSuiteBase() *tidbTestSuiteBase { - return &tidbTestSuiteBase{ - testServerClient: newTestServerClient(), - } -} - -var _ = SerialSuites(&tidbTestTopSQLSuite{newTiDBTestSuiteBase()}) - -func (ts *tidbTestTopSQLSuite) SetUpSuite(c *C) { - ts.tidbTestSuiteBase.SetUpSuite(c) - - // Initialize global variable for top-sql test. - db, err := sql.Open("mysql", ts.getDSN()) - c.Assert(err, IsNil, Commentf("Error connecting")) - defer func() { - err := db.Close() - c.Assert(err, IsNil) - }() - - dbt := &DBTest{c, db} - dbt.mustExec("set @@global.tidb_top_sql_precision_seconds=1;") - dbt.mustExec("set @@global.tidb_top_sql_report_interval_seconds=2;") - dbt.mustExec("set @@global.tidb_top_sql_max_statement_count=5;") - - tracecpu.GlobalSQLCPUProfiler.Run() -} - -func (ts *tidbTestTopSQLSuite) TearDownSuite(c *C) { - ts.tidbTestSuiteBase.TearDownSuite(c) -} - -func (ts *tidbTestSuiteBase) SetUpSuite(c *C) { - var err error - ts.store, err = mockstore.NewMockStore() - session.DisableStats4Test() - c.Assert(err, IsNil) - ts.domain, err = session.BootstrapSession(ts.store) - c.Assert(err, IsNil) - ts.tidbdrv = NewTiDBDriver(ts.store) - cfg := newTestConfig() - cfg.Socket = "" - cfg.Port = ts.port - cfg.Status.ReportStatus = true - cfg.Status.StatusPort = ts.statusPort - cfg.Performance.TCPKeepAlive = true - err = logutil.InitLogger(cfg.Log.ToLogConfig()) - c.Assert(err, IsNil) - - server, err := NewServer(cfg, ts.tidbdrv) - c.Assert(err, IsNil) - ts.port = getPortFromTCPAddr(server.listener.Addr()) - ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) - ts.server = server - go func() { - err := ts.server.Run() - c.Assert(err, IsNil) - }() - ts.waitUntilServerOnline() -} - -func (ts *tidbTestSuiteBase) TearDownSuite(c *C) { - if ts.store != nil { - ts.store.Close() - } - if ts.domain != nil { - ts.domain.Close() - } - if ts.server != nil { - ts.server.Close() - } -} - func TestRegression(t *testing.T) { ts, cleanup := createTiDBTest(t) defer cleanup() @@ -1750,25 +1665,6 @@ func TestTopSQLAgent(t *testing.T) { cancel5() } -func (ts *tidbTestTopSQLSuite) loopExec(ctx context.Context, c *C, fn func(db *sql.DB)) { - db, err := sql.Open("mysql", ts.getDSN()) - c.Assert(err, IsNil, Commentf("Error connecting")) - defer func() { - err := db.Close() - c.Assert(err, IsNil) - }() - dbt := &DBTest{c, db} - dbt.mustExec("use topsql;") - for { - select { - case <-ctx.Done(): - return - default: - } - fn(db) - } -} - func (ts *tidbTestTopSQL) loopExec(ctx context.Context, t *testing.T, fn func(db *sql.DB)) { db, err := sql.Open("mysql", ts.getDSN()) require.NoError(t, err, "Error connecting") From ad42d715c1b13aaae9dce83de2ccf9325998ed99 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 20 Nov 2021 20:08:25 +0800 Subject: [PATCH 53/55] server: cleanup server_test.go --- server/main_test.go | 2 + server/server.go | 2 - server/server_test.go | 144 ------------------------------------------ 3 files changed, 2 insertions(+), 146 deletions(-) diff --git a/server/main_test.go b/server/main_test.go index e4e7e9bf8d1e4..155d9f9b7294e 100644 --- a/server/main_test.go +++ b/server/main_test.go @@ -31,6 +31,8 @@ import ( func TestMain(m *testing.M) { testbridge.WorkaroundGoCheckFlags() + runInGoTest = true // flag for NewServer to known it is running in test environment + // AsyncCommit will make DDL wait 2.5s before changing to the next state. // Set schema lease to avoid it from making CI slow. session.SetSchemaLease(0) diff --git a/server/server.go b/server/server.go index e656569a14a77..34f71b6ff06e3 100644 --- a/server/server.go +++ b/server/server.go @@ -32,7 +32,6 @@ package server import ( "context" "crypto/tls" - "flag" "fmt" "math/rand" "net" @@ -88,7 +87,6 @@ func init() { if err != nil { osVersion = "" } - runInGoTest = flag.Lookup("test.v") != nil || flag.Lookup("check.v") != nil } var ( diff --git a/server/server_test.go b/server/server_test.go index b8f928f4e0c37..add08f06202ed 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -32,7 +32,6 @@ import ( "time" "github.com/go-sql-driver/mysql" - . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" @@ -49,11 +48,6 @@ var ( regression = true ) -func TestT(t *testing.T) { - CustomVerboseFlag = true - TestingT(t) -} - type configOverrider func(*mysql.Config) // testServerClient config server connect parameters and provider several @@ -137,43 +131,6 @@ func (cli *testingServerClient) runTests(t *testing.T, overrider configOverrider } } -// runTestsOnNewDB runs tests using a specified database which will be created before the test and destroyed after the test. -func (cli *testServerClient) runTestsOnNewDB(c *C, overrider configOverrider, dbName string, tests ...func(dbt *DBTest)) { - dsn := cli.getDSN(overrider, func(config *mysql.Config) { - config.DBName = "" - }) - db, err := sql.Open("mysql", dsn) - c.Assert(err, IsNil, Commentf("Error connecting")) - defer func() { - err := db.Close() - c.Assert(err, IsNil) - }() - - _, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`;", dbName)) - if err != nil { - fmt.Println(err) - } - c.Assert(err, IsNil, Commentf("Error drop database %s: %s", dbName, err)) - - _, err = db.Exec(fmt.Sprintf("CREATE DATABASE `%s`;", dbName)) - c.Assert(err, IsNil, Commentf("Error create database %s: %s", dbName, err)) - - defer func() { - _, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`;", dbName)) - c.Assert(err, IsNil, Commentf("Error drop database %s: %s", dbName, err)) - }() - - _, err = db.Exec(fmt.Sprintf("USE `%s`;", dbName)) - c.Assert(err, IsNil, Commentf("Error use database %s: %s", dbName, err)) - - dbt := &DBTest{c, db} - for _, test := range tests { - test(dbt) - // to fix : no db selected - _, _ = dbt.db.Exec("DROP TABLE IF EXISTS test") - } -} - // runTestsOnNewDB runs tests using a specified database which will be created before the test and destroyed after the test. func (cli *testingServerClient) runTestsOnNewDB(t *testing.T, overrider configOverrider, dbName string, tests ...func(dbt *testkit.DBTestKit)) { dsn := cli.getDSN(overrider, func(config *mysql.Config) { @@ -211,107 +168,6 @@ func (cli *testingServerClient) runTestsOnNewDB(t *testing.T, overrider configOv } } -type DBTest struct { - *C - db *sql.DB -} - -func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { - res, err := dbt.db.Exec(query, args...) - dbt.Assert(err, IsNil, Commentf("Exec %s", query)) - return res -} - -func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { - rows, err := dbt.db.Query(query, args...) - dbt.Assert(err, IsNil, Commentf("Query %s", query)) - return rows -} - -func (dbt *DBTest) mustQueryRows(query string, args ...interface{}) { - rows := dbt.mustQuery(query, args...) - dbt.Assert(rows.Next(), IsTrue) - rows.Close() -} - -func (cli *testServerClient) runTestRegression(c *C, overrider configOverrider, dbName string) { - cli.runTestsOnNewDB(c, overrider, dbName, func(dbt *DBTest) { - // Show the user - dbt.mustExec("select user()") - - // Create Table - dbt.mustExec("CREATE TABLE test (val TINYINT)") - - // Test for unexpected data - var out bool - rows := dbt.mustQuery("SELECT * FROM test") - dbt.Assert(rows.Next(), IsFalse, Commentf("unexpected data in empty table")) - - // Create Data - res := dbt.mustExec("INSERT INTO test VALUES (1)") - // res := dbt.mustExec("INSERT INTO test VALUES (?)", 1) - count, err := res.RowsAffected() - dbt.Assert(err, IsNil) - dbt.Check(count, Equals, int64(1)) - id, err := res.LastInsertId() - dbt.Assert(err, IsNil) - dbt.Check(id, Equals, int64(0)) - - // Read - rows = dbt.mustQuery("SELECT val FROM test") - if rows.Next() { - err = rows.Scan(&out) - c.Assert(err, IsNil) - dbt.Check(out, IsTrue) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) - } else { - dbt.Error("no data") - } - rows.Close() - - // Update - res = dbt.mustExec("UPDATE test SET val = 0 WHERE val = ?", 1) - count, err = res.RowsAffected() - dbt.Assert(err, IsNil) - dbt.Check(count, Equals, int64(1)) - - // Check Update - rows = dbt.mustQuery("SELECT val FROM test") - if rows.Next() { - err = rows.Scan(&out) - c.Assert(err, IsNil) - dbt.Check(out, IsFalse) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) - } else { - dbt.Error("no data") - } - rows.Close() - - // Delete - res = dbt.mustExec("DELETE FROM test WHERE val = 0") - // res = dbt.mustExec("DELETE FROM test WHERE val = ?", 0) - count, err = res.RowsAffected() - dbt.Assert(err, IsNil) - dbt.Check(count, Equals, int64(1)) - - // Check for unexpected rows - res = dbt.mustExec("DELETE FROM test") - count, err = res.RowsAffected() - dbt.Assert(err, IsNil) - dbt.Check(count, Equals, int64(0)) - - dbt.mustQueryRows("SELECT 1") - - var b = make([]byte, 0) - if err := dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { - dbt.Fatal(err) - } - if b == nil { - dbt.Error("nil echo from non-nil input") - } - }) -} - func (cli *testingServerClient) runTestRegression(t *testing.T, overrider configOverrider, dbName string) { cli.runTestsOnNewDB(t, overrider, dbName, func(dbt *testkit.DBTestKit) { // Show the user From 9e5b1f686c9b818a343c3316a2652349a6aa4163 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 20 Nov 2021 20:25:10 +0800 Subject: [PATCH 54/55] server: polish method's naming --- server/tidb_serial_test.go | 28 +++++------ server/tidb_test.go | 97 +++++++++++++++++--------------------- 2 files changed, 58 insertions(+), 67 deletions(-) diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 353c9bb44d9c9..9bc6b65198a8d 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -36,7 +36,7 @@ import ( // this test will change `kv.TxnTotalSizeLimit` which may affect other test suites, // so we must make it running in serial. func TestLoadData(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestLoadData(t, ts.server) @@ -45,7 +45,7 @@ func TestLoadData(t *testing.T) { } func TestConfigDefaultValue(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestsOnNewDB(t, nil, "config", func(dbt *testkit.DBTestKit) { @@ -57,35 +57,35 @@ func TestConfigDefaultValue(t *testing.T) { // Fix issue#22540. Change tidb_dml_batch_size, // then check if load data into table with auto random column works properly. func TestLoadDataAutoRandom(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestLoadDataAutoRandom(t) } func TestLoadDataAutoRandomWithSpecialTerm(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestLoadDataAutoRandomWithSpecialTerm(t) } func TestExplainFor(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestExplainForConn(t) } func TestStmtCount(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestStmtCount(t) } func TestLoadDataListPartition(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestLoadDataForListPartition(t) @@ -95,7 +95,7 @@ func TestLoadDataListPartition(t *testing.T) { } func TestTLSAuto(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() // Start the server without TLS configure, letting the server create these as AutoTLS is enabled @@ -126,7 +126,7 @@ func TestTLSAuto(t *testing.T) { } func TestTLSBasic(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() // Generate valid TLS certificates. @@ -201,7 +201,7 @@ func TestTLSBasic(t *testing.T) { } func TestTLSVerify(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() // Generate valid TLS certificates. @@ -275,7 +275,7 @@ func TestTLSVerify(t *testing.T) { } func TestErrorNoRollback(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() // Generate valid TLS certificates. @@ -345,7 +345,7 @@ func TestErrorNoRollback(t *testing.T) { } func TestPrepareCount(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) @@ -368,7 +368,7 @@ func TestPrepareCount(t *testing.T) { } func TestDefaultCharacterAndCollation(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() // issue #21194 @@ -394,7 +394,7 @@ func TestDefaultCharacterAndCollation(t *testing.T) { } func TestReloadTLS(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() // Generate valid TLS certificates. diff --git a/server/tidb_test.go b/server/tidb_test.go index 8ade51be762c0..a306ecd7f1f66 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -56,7 +56,7 @@ import ( "github.com/stretchr/testify/require" ) -type tidbTestBase struct { +type tidbTestSuite struct { *testingServerClient tidbdrv *TiDBDriver server *Server @@ -64,10 +64,10 @@ type tidbTestBase struct { store kv.Storage } -func createTiDBTestBase(t *testing.T) (*tidbTestBase, func()) { - ts := &tidbTestBase{testingServerClient: newTestingServerClient()} +func createTidbTestSuite(t *testing.T) (*tidbTestSuite, func()) { + ts := &tidbTestSuite{testingServerClient: newTestingServerClient()} - // setup tidbTestBase + // setup tidbTestSuite var err error ts.store, err = mockstore.NewMockStore() session.DisableStats4Test() @@ -110,23 +110,14 @@ func createTiDBTestBase(t *testing.T) (*tidbTestBase, func()) { return ts, cleanup } -type tidbTest struct { - *tidbTestBase +type tidbTestTopSQLSuite struct { + *tidbTestSuite } -func createTiDBTest(t *testing.T) (*tidbTest, func()) { - base, cleanup := createTiDBTestBase(t) - return &tidbTest{base}, cleanup -} - -type tidbTestTopSQL struct { - *tidbTestBase -} - -func createTiDBTestTopSQL(t *testing.T) (*tidbTestTopSQL, func()) { - base, cleanup := createTiDBTestBase(t) +func createTidbTestTopSQLSuite(t *testing.T) (*tidbTestTopSQLSuite, func()) { + base, cleanup := createTidbTestSuite(t) - ts := &tidbTestTopSQL{base} + ts := &tidbTestTopSQLSuite{base} // Initialize global variable for top-sql test. db, err := sql.Open("mysql", ts.getDSN()) @@ -147,7 +138,7 @@ func createTiDBTestTopSQL(t *testing.T) (*tidbTestTopSQL, func()) { } func TestRegression(t *testing.T) { - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() if regression { t.Parallel() @@ -157,7 +148,7 @@ func TestRegression(t *testing.T) { func TestUint64(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestPrepareResultFieldType(t) @@ -165,7 +156,7 @@ func TestUint64(t *testing.T) { func TestSpecialType(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestSpecialType(t) @@ -173,7 +164,7 @@ func TestSpecialType(t *testing.T) { func TestPreparedString(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestPreparedString(t) @@ -181,7 +172,7 @@ func TestPreparedString(t *testing.T) { func TestPreparedTimestamp(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestPreparedTimestamp(t) @@ -189,7 +180,7 @@ func TestPreparedTimestamp(t *testing.T) { func TestConcurrentUpdate(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestConcurrentUpdate(t) @@ -197,7 +188,7 @@ func TestConcurrentUpdate(t *testing.T) { func TestErrorCode(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestErrorCode(t) @@ -205,7 +196,7 @@ func TestErrorCode(t *testing.T) { func TestAuth(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestAuth(t) @@ -214,7 +205,7 @@ func TestAuth(t *testing.T) { func TestIssues(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestIssue3662(t) @@ -224,14 +215,14 @@ func TestIssues(t *testing.T) { func TestDBNameEscape(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestDBNameEscape(t) } func TestResultFieldTableIsNull(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestResultFieldTableIsNull(t) @@ -239,7 +230,7 @@ func TestResultFieldTableIsNull(t *testing.T) { func TestStatusAPI(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestStatusAPI(t) @@ -247,7 +238,7 @@ func TestStatusAPI(t *testing.T) { func TestStatusPort(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() cfg := newTestConfig() @@ -264,7 +255,7 @@ func TestStatusPort(t *testing.T) { func TestStatusAPIWithTLS(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() caCert, caKey, err := generateCert(0, "TiDB CA 2", nil, nil, "/tmp/ca-key-2.pem", "/tmp/ca-cert-2.pem") @@ -311,7 +302,7 @@ func TestStatusAPIWithTLS(t *testing.T) { func TestStatusAPIWithTLSCNCheck(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() caPath := filepath.Join(os.TempDir(), "ca-cert-cn.pem") @@ -391,7 +382,7 @@ func newTLSHttpClient(t *testing.T, caFile, certFile, keyFile string) *http.Clie func TestMultiStatements(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runFailedTestMultiStatements(t) @@ -406,7 +397,7 @@ func TestSocketForwarding(t *testing.T) { socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK defer os.RemoveAll(tempDir) - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() cli := newTestingServerClient() @@ -450,7 +441,7 @@ func TestSocket(t *testing.T) { cfg.Host = "" cfg.Status.ReportStatus = false - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() server, err := NewServer(cfg, ts.tidbdrv) @@ -487,7 +478,7 @@ func TestSocketAndIp(t *testing.T) { cfg.Port = cli.port cfg.Status.ReportStatus = false - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() server, err := NewServer(cfg, ts.tidbdrv) @@ -650,7 +641,7 @@ func TestOnlySocket(t *testing.T) { cfg.Host = "" // No network interface listening for mysql traffic cfg.Status.ReportStatus = false - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() server, err := NewServer(cfg, ts.tidbdrv) @@ -886,7 +877,7 @@ func registerTLSConfig(configName string, caCertPath string, clientCertPath stri func TestSystemTimeZone(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() tk := testkit.NewTestKit(t, ts.store) @@ -904,7 +895,7 @@ func TestSystemTimeZone(t *testing.T) { func TestClientWithCollation(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestClientWithCollation(t) @@ -912,7 +903,7 @@ func TestClientWithCollation(t *testing.T) { func TestCreateTableFlen(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() // issue #4540 @@ -986,7 +977,7 @@ func Execute(ctx context.Context, qc *TiDBContext, sql string) (ResultSet, error func TestShowTablesFlen(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) @@ -1018,7 +1009,7 @@ func checkColNames(t *testing.T, columns []*ColumnInfo, names ...string) { func TestFieldList(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) @@ -1101,28 +1092,28 @@ func TestFieldList(t *testing.T) { func TestClientErrors(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestInfoschemaClientErrors(t) } func TestInitConnect(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestInitConnect(t) } func TestSumAvg(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() ts.runTestSumAvg(t) } func TestNullFlag(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) @@ -1191,7 +1182,7 @@ func TestNullFlag(t *testing.T) { func TestNO_DEFAULT_VALUEFlag(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() // issue #21465 @@ -1215,7 +1206,7 @@ func TestNO_DEFAULT_VALUEFlag(t *testing.T) { func TestGracefulShutdown(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() cli := newTestServerClient() @@ -1257,7 +1248,7 @@ func TestGracefulShutdown(t *testing.T) { func TestPessimisticInsertSelectForUpdate(t *testing.T) { t.Parallel() - ts, cleanup := createTiDBTest(t) + ts, cleanup := createTidbTestSuite(t) defer cleanup() qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) @@ -1286,7 +1277,7 @@ type collectorWrapper struct { } func TestTopSQLCPUProfile(t *testing.T) { - ts, cleanup := createTiDBTestTopSQL(t) + ts, cleanup := createTidbTestTopSQLSuite(t) defer cleanup() db, err := sql.Open("mysql", ts.getDSN()) @@ -1513,7 +1504,7 @@ func TestTopSQLCPUProfile(t *testing.T) { func TestTopSQLAgent(t *testing.T) { t.Skip("unstable, skip it and fix it before 20210702") - ts, cleanup := createTiDBTestTopSQL(t) + ts, cleanup := createTidbTestTopSQLSuite(t) defer cleanup() db, err := sql.Open("mysql", ts.getDSN()) require.NoError(t, err, "Error connecting") @@ -1665,7 +1656,7 @@ func TestTopSQLAgent(t *testing.T) { cancel5() } -func (ts *tidbTestTopSQL) loopExec(ctx context.Context, t *testing.T, fn func(db *sql.DB)) { +func (ts *tidbTestTopSQLSuite) loopExec(ctx context.Context, t *testing.T, fn func(db *sql.DB)) { db, err := sql.Open("mysql", ts.getDSN()) require.NoError(t, err, "Error connecting") defer func() { From 09f7ff49a847e488c3ac2b88efeb35ba7af85644 Mon Sep 17 00:00:00 2001 From: yedamo Date: Sat, 20 Nov 2021 20:33:57 +0800 Subject: [PATCH 55/55] server: remove scaffolding for the migration --- server/server_test.go | 91 ++++++++++++++++---------------------- server/tidb_serial_test.go | 10 ++--- server/tidb_test.go | 12 ++--- 3 files changed, 50 insertions(+), 63 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index add08f06202ed..210e58caed3f8 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -67,19 +67,6 @@ func newTestServerClient() *testServerClient { } } -type testingServerClient struct { - testServerClient -} - -// newTestServerClient return a testingServerClient with unique address -func newTestingServerClient() *testingServerClient { - return &testingServerClient{testServerClient{ - port: 0, - statusPort: 0, - statusScheme: "http", - }} -} - // statusURL return the full URL of a status path func (cli *testServerClient) statusURL(path string) string { return fmt.Sprintf("%s://localhost:%d%s", cli.statusScheme, cli.statusPort, path) @@ -117,7 +104,7 @@ func (cli *testServerClient) getDSN(overriders ...configOverrider) string { } // runTests runs tests using the default database `test`. -func (cli *testingServerClient) runTests(t *testing.T, overrider configOverrider, tests ...func(dbt *testkit.DBTestKit)) { +func (cli *testServerClient) runTests(t *testing.T, overrider configOverrider, tests ...func(dbt *testkit.DBTestKit)) { db, err := sql.Open("mysql", cli.getDSN(overrider)) require.NoErrorf(t, err, "Error connecting") defer func() { @@ -132,7 +119,7 @@ func (cli *testingServerClient) runTests(t *testing.T, overrider configOverrider } // runTestsOnNewDB runs tests using a specified database which will be created before the test and destroyed after the test. -func (cli *testingServerClient) runTestsOnNewDB(t *testing.T, overrider configOverrider, dbName string, tests ...func(dbt *testkit.DBTestKit)) { +func (cli *testServerClient) runTestsOnNewDB(t *testing.T, overrider configOverrider, dbName string, tests ...func(dbt *testkit.DBTestKit)) { dsn := cli.getDSN(overrider, func(config *mysql.Config) { config.DBName = "" }) @@ -168,7 +155,7 @@ func (cli *testingServerClient) runTestsOnNewDB(t *testing.T, overrider configOv } } -func (cli *testingServerClient) runTestRegression(t *testing.T, overrider configOverrider, dbName string) { +func (cli *testServerClient) runTestRegression(t *testing.T, overrider configOverrider, dbName string) { cli.runTestsOnNewDB(t, overrider, dbName, func(dbt *testkit.DBTestKit) { // Show the user dbt.MustExec("select user()") @@ -246,7 +233,7 @@ func (cli *testingServerClient) runTestRegression(t *testing.T, overrider config }) } -func (cli *testingServerClient) runTestPrepareResultFieldType(t *testing.T) { +func (cli *testServerClient) runTestPrepareResultFieldType(t *testing.T) { var param int64 = 83 cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { stmt, err := dbt.GetDB().Prepare(`SELECT ?`) @@ -266,7 +253,7 @@ func (cli *testingServerClient) runTestPrepareResultFieldType(t *testing.T) { }) } -func (cli *testingServerClient) runTestSpecialType(t *testing.T) { +func (cli *testServerClient) runTestSpecialType(t *testing.T) { cli.runTestsOnNewDB(t, nil, "SpecialType", func(dbt *testkit.DBTestKit) { dbt.MustExec("create table test (a decimal(10, 5), b datetime, c time, d bit(8))") dbt.MustExec("insert test values (1.4, '2012-12-21 12:12:12', '4:23:34', b'1000')") @@ -284,7 +271,7 @@ func (cli *testingServerClient) runTestSpecialType(t *testing.T) { }) } -func (cli *testingServerClient) runTestClientWithCollation(t *testing.T) { +func (cli *testServerClient) runTestClientWithCollation(t *testing.T) { cli.runTests(t, func(config *mysql.Config) { config.Collation = "utf8mb4_general_ci" }, func(dbt *testkit.DBTestKit) { @@ -320,7 +307,7 @@ func (cli *testingServerClient) runTestClientWithCollation(t *testing.T) { }) } -func (cli *testingServerClient) runTestPreparedString(t *testing.T) { +func (cli *testServerClient) runTestPreparedString(t *testing.T) { cli.runTestsOnNewDB(t, nil, "PreparedString", func(dbt *testkit.DBTestKit) { dbt.MustExec("create table test (a char(10), b char(10))") dbt.MustExec("insert test values (?, ?)", "abcdeabcde", "abcde") @@ -337,7 +324,7 @@ func (cli *testingServerClient) runTestPreparedString(t *testing.T) { // runTestPreparedTimestamp does not really cover binary timestamp format, because MySQL driver in golang // does not use this format. MySQL driver in golang will convert the timestamp to a string. // This case guarantees it could work. -func (cli *testingServerClient) runTestPreparedTimestamp(t *testing.T) { +func (cli *testServerClient) runTestPreparedTimestamp(t *testing.T) { cli.runTestsOnNewDB(t, nil, "prepared_timestamp", func(dbt *testkit.DBTestKit) { dbt.MustExec("create table test (a timestamp, b time)") dbt.MustExec("set time_zone='+00:00'") @@ -358,7 +345,7 @@ func (cli *testingServerClient) runTestPreparedTimestamp(t *testing.T) { }) } -func (cli *testingServerClient) runTestLoadDataWithSelectIntoOutfile(t *testing.T, server *Server) { +func (cli *testServerClient) runTestLoadDataWithSelectIntoOutfile(t *testing.T, server *Server) { cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" @@ -405,7 +392,7 @@ func (cli *testingServerClient) runTestLoadDataWithSelectIntoOutfile(t *testing. }) } -func (cli *testingServerClient) runTestLoadDataForSlowLog(t *testing.T, server *Server) { +func (cli *testServerClient) runTestLoadDataForSlowLog(t *testing.T, server *Server) { path := "/tmp/load_data_test.csv" fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) require.NoError(t, err) @@ -463,7 +450,7 @@ func (cli *testingServerClient) runTestLoadDataForSlowLog(t *testing.T, server * }) } -func (cli *testingServerClient) prepareLoadDataFile(t *testing.T, path string, rows ...string) { +func (cli *testServerClient) prepareLoadDataFile(t *testing.T, path string, rows ...string) { fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) require.NoError(t, err) require.NotNil(t, fp) @@ -479,7 +466,7 @@ func (cli *testingServerClient) prepareLoadDataFile(t *testing.T, path string, r require.NoError(t, err) } -func (cli *testingServerClient) runTestLoadDataAutoRandom(t *testing.T) { +func (cli *testServerClient) runTestLoadDataAutoRandom(t *testing.T) { path := "/tmp/load_data_txn_error.csv" fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) @@ -534,7 +521,7 @@ func (cli *testingServerClient) runTestLoadDataAutoRandom(t *testing.T) { }) } -func (cli *testingServerClient) runTestLoadDataAutoRandomWithSpecialTerm(t *testing.T) { +func (cli *testServerClient) runTestLoadDataAutoRandomWithSpecialTerm(t *testing.T) { path := "/tmp/load_data_txn_error_term.csv" fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) @@ -591,7 +578,7 @@ func (cli *testingServerClient) runTestLoadDataAutoRandomWithSpecialTerm(t *test }) } -func (cli *testingServerClient) runTestLoadDataForListPartition(t *testing.T) { +func (cli *testServerClient) runTestLoadDataForListPartition(t *testing.T) { path := "/tmp/load_data_list_partition.csv" defer func() { _ = os.Remove(path) @@ -640,7 +627,7 @@ func (cli *testingServerClient) runTestLoadDataForListPartition(t *testing.T) { }) } -func (cli *testingServerClient) runTestLoadDataForListPartition2(t *testing.T) { +func (cli *testServerClient) runTestLoadDataForListPartition2(t *testing.T) { path := "/tmp/load_data_list_partition.csv" defer func() { _ = os.Remove(path) @@ -689,7 +676,7 @@ func (cli *testingServerClient) runTestLoadDataForListPartition2(t *testing.T) { }) } -func (cli *testingServerClient) runTestLoadDataForListColumnPartition(t *testing.T) { +func (cli *testServerClient) runTestLoadDataForListColumnPartition(t *testing.T) { path := "/tmp/load_data_list_partition.csv" defer func() { _ = os.Remove(path) @@ -738,7 +725,7 @@ func (cli *testingServerClient) runTestLoadDataForListColumnPartition(t *testing }) } -func (cli *testingServerClient) runTestLoadDataForListColumnPartition2(t *testing.T) { +func (cli *testServerClient) runTestLoadDataForListColumnPartition2(t *testing.T) { path := "/tmp/load_data_list_partition.csv" defer func() { _ = os.Remove(path) @@ -792,7 +779,7 @@ func (cli *testingServerClient) runTestLoadDataForListColumnPartition2(t *testin }) } -func (cli *testingServerClient) checkRows(t *testing.T, rows *sql.Rows, expectedRows ...string) { +func (cli *testServerClient) checkRows(t *testing.T, rows *sql.Rows, expectedRows ...string) { buf := bytes.NewBuffer(nil) result := make([]string, 0, 2) for rows.Next() { @@ -823,7 +810,7 @@ func (cli *testingServerClient) checkRows(t *testing.T, rows *sql.Rows, expected require.Equal(t, strings.Join(expectedRows, "\n"), strings.Join(result, "\n")) } -func (cli *testingServerClient) runTestLoadData(t *testing.T, server *Server) { +func (cli *testServerClient) runTestLoadData(t *testing.T, server *Server) { // create a file and write data. path := "/tmp/load_data_test.csv" fp, err := os.Create(path) @@ -1377,7 +1364,7 @@ func (cli *testingServerClient) runTestLoadData(t *testing.T, server *Server) { }) } -func (cli *testingServerClient) runTestConcurrentUpdate(t *testing.T) { +func (cli *testServerClient) runTestConcurrentUpdate(t *testing.T) { dbName := "Concurrent" cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.Params["sql_mode"] = "''" @@ -1410,7 +1397,7 @@ func (cli *testingServerClient) runTestConcurrentUpdate(t *testing.T) { }) } -func (cli *testingServerClient) runTestExplainForConn(t *testing.T) { +func (cli *testServerClient) runTestExplainForConn(t *testing.T) { cli.runTestsOnNewDB(t, nil, "explain_for_conn", func(dbt *testkit.DBTestKit) { dbt.MustExec("drop table if exists t") dbt.MustExec("create table t (a int key, b int)") @@ -1432,7 +1419,7 @@ func (cli *testingServerClient) runTestExplainForConn(t *testing.T) { }) } -func (cli *testingServerClient) runTestErrorCode(t *testing.T) { +func (cli *testServerClient) runTestErrorCode(t *testing.T) { cli.runTestsOnNewDB(t, nil, "ErrorCode", func(dbt *testkit.DBTestKit) { dbt.MustExec("create table test (c int PRIMARY KEY);") dbt.MustExec("insert into test values (1);") @@ -1504,7 +1491,7 @@ func checkErrorCode(t *testing.T, e error, codes ...uint16) { require.Truef(t, isMatchCode, "got err %v, expected err codes %v", me, codes) } -func (cli *testingServerClient) runTestAuth(t *testing.T) { +func (cli *testServerClient) runTestAuth(t *testing.T) { cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { dbt.MustExec(`CREATE USER 'authtest'@'%' IDENTIFIED BY '123';`) dbt.MustExec(`CREATE ROLE 'authtest_r1'@'%';`) @@ -1558,7 +1545,7 @@ func (cli *testingServerClient) runTestAuth(t *testing.T) { }) } -func (cli *testingServerClient) runTestIssue3662(t *testing.T) { +func (cli *testServerClient) runTestIssue3662(t *testing.T) { db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { config.DBName = "non_existing_schema" })) @@ -1576,7 +1563,7 @@ func (cli *testingServerClient) runTestIssue3662(t *testing.T) { require.Equal(t, "Error 1049: Unknown database 'non_existing_schema'", err.Error()) } -func (cli *testingServerClient) runTestIssue3680(t *testing.T) { +func (cli *testServerClient) runTestIssue3680(t *testing.T) { db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { config.User = "non_existing_user" })) @@ -1594,7 +1581,7 @@ func (cli *testingServerClient) runTestIssue3680(t *testing.T) { require.Equal(t, "Error 1045: Access denied for user 'non_existing_user'@'127.0.0.1' (using password: NO)", err.Error()) } -func (cli *testingServerClient) runTestIssue22646(t *testing.T) { +func (cli *testServerClient) runTestIssue22646(t *testing.T) { cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { c1 := make(chan string, 1) go func() { @@ -1610,7 +1597,7 @@ func (cli *testingServerClient) runTestIssue22646(t *testing.T) { }) } -func (cli *testingServerClient) runTestIssue3682(t *testing.T) { +func (cli *testServerClient) runTestIssue3682(t *testing.T) { cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { dbt.MustExec(`CREATE USER 'issue3682'@'%' IDENTIFIED BY '123';`) dbt.MustExec(`GRANT ALL on test.* to 'issue3682'`) @@ -1637,7 +1624,7 @@ func (cli *testingServerClient) runTestIssue3682(t *testing.T) { require.Equal(t, "Error 1045: Access denied for user 'issue3682'@'127.0.0.1' (using password: YES)", err.Error()) } -func (cli *testingServerClient) runTestDBNameEscape(t *testing.T) { +func (cli *testServerClient) runTestDBNameEscape(t *testing.T) { cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { dbt.MustExec("CREATE DATABASE `aa-a`;") }) @@ -1649,7 +1636,7 @@ func (cli *testingServerClient) runTestDBNameEscape(t *testing.T) { }) } -func (cli *testingServerClient) runTestResultFieldTableIsNull(t *testing.T) { +func (cli *testServerClient) runTestResultFieldTableIsNull(t *testing.T) { cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.Params["sql_mode"] = "''" }, "ResultFieldTableIsNull", func(dbt *testkit.DBTestKit) { @@ -1659,7 +1646,7 @@ func (cli *testingServerClient) runTestResultFieldTableIsNull(t *testing.T) { }) } -func (cli *testingServerClient) runTestStatusAPI(t *testing.T) { +func (cli *testServerClient) runTestStatusAPI(t *testing.T) { resp, err := cli.fetchStatus("/status") require.NoError(t, err) defer resp.Body.Close() @@ -1675,7 +1662,7 @@ func (cli *testingServerClient) runTestStatusAPI(t *testing.T) { // disabled by default for security reasons. Lets ensure that the behavior // is correct. -func (cli *testingServerClient) runFailedTestMultiStatements(t *testing.T) { +func (cli *testServerClient) runFailedTestMultiStatements(t *testing.T) { cli.runTestsOnNewDB(t, nil, "FailedMultiStatements", func(dbt *testkit.DBTestKit) { // Default is now OFF in new installations. @@ -1737,7 +1724,7 @@ func (cli *testingServerClient) runFailedTestMultiStatements(t *testing.T) { }) } -func (cli *testingServerClient) runTestMultiStatements(t *testing.T) { +func (cli *testServerClient) runTestMultiStatements(t *testing.T) { cli.runTestsOnNewDB(t, func(config *mysql.Config) { config.Params["multiStatements"] = "true" @@ -1793,7 +1780,7 @@ func (cli *testingServerClient) runTestMultiStatements(t *testing.T) { }) } -func (cli *testingServerClient) runTestStmtCount(t *testing.T) { +func (cli *testServerClient) runTestStmtCount(t *testing.T) { cli.runTestsOnNewDB(t, nil, "StatementCount", func(dbt *testkit.DBTestKit) { originStmtCnt := getStmtCnt(string(cli.getMetrics(t))) @@ -1828,7 +1815,7 @@ func (cli *testingServerClient) runTestStmtCount(t *testing.T) { }) } -func (cli *testingServerClient) runTestTLSConnection(t *testing.T, overrider configOverrider) error { +func (cli *testServerClient) runTestTLSConnection(t *testing.T, overrider configOverrider) error { dsn := cli.getDSN(overrider) db, err := sql.Open("mysql", dsn) require.NoError(t, err) @@ -1843,7 +1830,7 @@ func (cli *testingServerClient) runTestTLSConnection(t *testing.T, overrider con return err } -func (cli *testingServerClient) runReloadTLS(t *testing.T, overrider configOverrider, errorNoRollback bool) error { +func (cli *testServerClient) runReloadTLS(t *testing.T, overrider configOverrider, errorNoRollback bool) error { db, err := sql.Open("mysql", cli.getDSN(overrider)) require.NoError(t, err) defer func() { @@ -1858,7 +1845,7 @@ func (cli *testingServerClient) runReloadTLS(t *testing.T, overrider configOverr return err } -func (cli *testingServerClient) runTestSumAvg(t *testing.T) { +func (cli *testServerClient) runTestSumAvg(t *testing.T) { cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { dbt.MustExec("create table sumavg (a int, b decimal, c double)") dbt.MustExec("insert sumavg values (1, 1, 1)") @@ -1880,7 +1867,7 @@ func (cli *testingServerClient) runTestSumAvg(t *testing.T) { }) } -func (cli *testingServerClient) getMetrics(t *testing.T) []byte { +func (cli *testServerClient) getMetrics(t *testing.T) []byte { resp, err := cli.fetchStatus("/metrics") require.NoError(t, err) content, err := io.ReadAll(resp.Body) @@ -1942,7 +1929,7 @@ func (cli *testServerClient) waitUntilServerOnline() { } } -func (cli *testingServerClient) runTestInitConnect(t *testing.T) { +func (cli *testServerClient) runTestInitConnect(t *testing.T) { cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { dbt.MustExec(`SET GLOBAL init_connect="insert into test.ts VALUES (NOW());SET @a=1;"`) @@ -1999,7 +1986,7 @@ func (cli *testingServerClient) runTestInitConnect(t *testing.T) { // Client errors are only incremented when using the TiDB Server protocol, // and not internal SQL statements. Thus, this test is in the server-test suite. -func (cli *testingServerClient) runTestInfoschemaClientErrors(t *testing.T) { +func (cli *testServerClient) runTestInfoschemaClientErrors(t *testing.T) { cli.runTestsOnNewDB(t, nil, "clientErrors", func(dbt *testkit.DBTestKit) { clientErrors := []struct { diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 9bc6b65198a8d..0431baa32fa8f 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -102,7 +102,7 @@ func TestTLSAuto(t *testing.T) { connOverrider := func(config *mysql.Config) { config.TLSConfig = "skip-verify" } - cli := newTestingServerClient() + cli := newTestServerClient() cfg := newTestConfig() cfg.Socket = "" cfg.Port = cli.port @@ -158,7 +158,7 @@ func TestTLSBasic(t *testing.T) { connOverrider := func(config *mysql.Config) { config.TLSConfig = "skip-verify" } - cli := newTestingServerClient() + cli := newTestServerClient() cfg := newTestConfig() cfg.Socket = "" cfg.Port = cli.port @@ -230,7 +230,7 @@ func TestTLSVerify(t *testing.T) { }() // Start the server with TLS & CA, if the client presents its certificate, the certificate will be verified. - cli := newTestingServerClient() + cli := newTestServerClient() cfg := newTestConfig() cfg.Socket = "" cfg.Port = cli.port @@ -298,7 +298,7 @@ func TestErrorNoRollback(t *testing.T) { os.Remove("/tmp/client-cert-rollback.pem") }() - cli := newTestingServerClient() + cli := newTestServerClient() cfg := newTestConfig() cfg.Socket = "" cfg.Port = cli.port @@ -418,7 +418,7 @@ func TestReloadTLS(t *testing.T) { }() // try old cert used in startup configuration. - cli := newTestingServerClient() + cli := newTestServerClient() cfg := newTestConfig() cfg.Socket = "" cfg.Port = cli.port diff --git a/server/tidb_test.go b/server/tidb_test.go index a306ecd7f1f66..01a19d70df6d1 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -57,7 +57,7 @@ import ( ) type tidbTestSuite struct { - *testingServerClient + *testServerClient tidbdrv *TiDBDriver server *Server domain *domain.Domain @@ -65,7 +65,7 @@ type tidbTestSuite struct { } func createTidbTestSuite(t *testing.T) (*tidbTestSuite, func()) { - ts := &tidbTestSuite{testingServerClient: newTestingServerClient()} + ts := &tidbTestSuite{testServerClient: newTestServerClient()} // setup tidbTestSuite var err error @@ -400,7 +400,7 @@ func TestSocketForwarding(t *testing.T) { ts, cleanup := createTidbTestSuite(t) defer cleanup() - cli := newTestingServerClient() + cli := newTestServerClient() cfg := newTestConfig() cfg.Socket = socketFile cfg.Port = cli.port @@ -454,7 +454,7 @@ func TestSocket(t *testing.T) { defer server.Close() // a fake server client, config is override, just used to run tests - cli := newTestingServerClient() + cli := newTestServerClient() cli.runTestRegression(t, func(config *mysql.Config) { config.User = "root" config.Net = "unix" @@ -472,7 +472,7 @@ func TestSocketAndIp(t *testing.T) { socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK defer os.RemoveAll(tempDir) - cli := newTestingServerClient() + cli := newTestServerClient() cfg := newTestConfig() cfg.Socket = socketFile cfg.Port = cli.port @@ -635,7 +635,7 @@ func TestOnlySocket(t *testing.T) { socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK defer os.RemoveAll(tempDir) - cli := newTestingServerClient() + cli := newTestServerClient() cfg := newTestConfig() cfg.Socket = socketFile cfg.Host = "" // No network interface listening for mysql traffic