diff --git a/DEPS.bzl b/DEPS.bzl index 97d5e63b1776b..32f321378fa33 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -167,8 +167,8 @@ def go_deps(): name = "com_github_blacktear23_go_proxyprotocol", build_file_proto_mode = "disable_global", importpath = "github.com/blacktear23/go-proxyprotocol", - sum = "h1:rQlvB2AYWme2bIB18r/SipGiMEVJYE9U0z+MGoU/LtQ=", - version = "v0.0.0-20180807104634-af7a81e8dd0d", + sum = "h1:WmMmtZanGEfIHnJN9N3A4Pl6mM69D+GxEph2eOaCf7g=", + version = "v1.0.0", ) go_repository( name = "com_github_burntsushi_toml", diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go old mode 100644 new mode 100755 index 11b44a6ba19c8..13773dc6d2ee4 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -232,8 +232,9 @@ type local struct { errorMgr *errormanager.ErrorManager importClientFactory ImportClientFactory - bufferPool *membuf.Pool - metrics *metric.Metrics + bufferPool *membuf.Pool + metrics *metric.Metrics + writeLimiter StoreWriteLimiter } func openDuplicateDB(storeDir string) (*pebble.DB, error) { @@ -308,6 +309,12 @@ func NewLocalBackend( if duplicateDetection { keyAdapter = dupDetectKeyAdapter{} } + var writeLimiter StoreWriteLimiter + if cfg.TikvImporter.StoreWriteBWLimit > 0 { + writeLimiter = newStoreWriteLimiter(int(cfg.TikvImporter.StoreWriteBWLimit)) + } else { + writeLimiter = noopStoreWriteLimiter{} + } local := &local{ engines: sync.Map{}, pdCtl: pdCtl, @@ -334,6 +341,7 @@ func NewLocalBackend( errorMgr: errorMgr, importClientFactory: importClientFactory, bufferPool: membuf.NewPool(membuf.WithAllocator(manual.Allocator{})), + writeLimiter: writeLimiter, } if m, ok := metric.FromContext(ctx); ok { local.metrics = m @@ -784,6 +792,7 @@ func (local *local) WriteToTiKV( leaderID := region.Leader.GetId() clients := make([]sst.ImportSST_WriteClient, 0, len(region.Region.GetPeers())) + storeIDs := make([]uint64, 0, len(region.Region.GetPeers())) requests := 
make([]*sst.WriteRequest, 0, len(region.Region.GetPeers())) for _, peer := range region.Region.GetPeers() { cli, err := local.getImportClient(ctx, peer.StoreId) @@ -812,6 +821,7 @@ func (local *local) WriteToTiKV( } clients = append(clients, wstream) requests = append(requests, req) + storeIDs = append(storeIDs, peer.StoreId) } bytesBuf := local.bufferPool.NewBuffer() @@ -819,43 +829,57 @@ func (local *local) WriteToTiKV( pairs := make([]*sst.Pair, 0, local.batchWriteKVPairs) count := 0 size := int64(0) + totalSize := int64(0) totalCount := int64(0) - firstLoop := true // if region-split-size <= 96MiB, we bump the threshold a bit to avoid too many retry split // because the range-properties is not 100% accurate regionMaxSize := regionSplitSize if regionSplitSize <= int64(config.SplitRegionSize) { regionMaxSize = regionSplitSize * 4 / 3 } + // Set a lower flush limit to make the speed of write more smooth. + flushLimit := int64(local.writeLimiter.Limit() / 10) + + flushKVs := func() error { + for i := range clients { + if err := local.writeLimiter.WaitN(ctx, storeIDs[i], int(size)); err != nil { + return errors.Trace(err) + } + requests[i].Chunk.(*sst.WriteRequest_Batch).Batch.Pairs = pairs[:count] + if err := clients[i].Send(requests[i]); err != nil { + return errors.Trace(err) + } + } + return nil + } for iter.First(); iter.Valid(); iter.Next() { - size += int64(len(iter.Key()) + len(iter.Value())) + kvSize := int64(len(iter.Key()) + len(iter.Value())) // here we reuse the `*sst.Pair`s to optimize object allocation - if firstLoop { + if count < len(pairs) { + pairs[count].Key = bytesBuf.AddBytes(iter.Key()) + pairs[count].Value = bytesBuf.AddBytes(iter.Value()) + } else { pair := &sst.Pair{ Key: bytesBuf.AddBytes(iter.Key()), Value: bytesBuf.AddBytes(iter.Value()), } pairs = append(pairs, pair) - } else { - pairs[count].Key = bytesBuf.AddBytes(iter.Key()) - pairs[count].Value = bytesBuf.AddBytes(iter.Value()) } count++ totalCount++ + size += kvSize + totalSize += 
kvSize - if count >= local.batchWriteKVPairs { - for i := range clients { - requests[i].Chunk.(*sst.WriteRequest_Batch).Batch.Pairs = pairs[:count] - if err := clients[i].Send(requests[i]); err != nil { - return nil, Range{}, stats, errors.Trace(err) - } + if count >= local.batchWriteKVPairs || size >= flushLimit { + if err := flushKVs(); err != nil { + return nil, Range{}, stats, err } count = 0 + size = 0 bytesBuf.Reset() - firstLoop = false } - if size >= regionMaxSize || totalCount >= regionSplitKeys { + if totalSize >= regionMaxSize || totalCount >= regionSplitKeys { break } } @@ -865,12 +889,12 @@ func (local *local) WriteToTiKV( } if count > 0 { - for i := range clients { - requests[i].Chunk.(*sst.WriteRequest_Batch).Batch.Pairs = pairs[:count] - if err := clients[i].Send(requests[i]); err != nil { - return nil, Range{}, stats, errors.Trace(err) - } + if err := flushKVs(); err != nil { + return nil, Range{}, stats, err } + count = 0 + size = 0 + bytesBuf.Reset() } var leaderPeerMetas []*sst.SSTMeta @@ -913,7 +937,7 @@ func (local *local) WriteToTiKV( logutil.Region(region.Region), logutil.Leader(region.Leader)) } stats.count = totalCount - stats.totalBytes = size + stats.totalBytes = totalSize return leaderPeerMetas, finishedRange, stats, nil } diff --git a/br/pkg/lightning/backend/local/localhelper.go b/br/pkg/lightning/backend/local/localhelper.go index 98413b20e71e0..c4aaae30db37b 100644 --- a/br/pkg/lightning/backend/local/localhelper.go +++ b/br/pkg/lightning/backend/local/localhelper.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "database/sql" + "math" "regexp" "runtime" "sort" @@ -40,6 +41,7 @@ import ( "go.uber.org/multierr" "go.uber.org/zap" "golang.org/x/sync/errgroup" + "golang.org/x/time/rate" ) const ( @@ -592,3 +594,75 @@ func intersectRange(region *metapb.Region, rg Range) Range { return Range{start: startKey, end: endKey} } + +type StoreWriteLimiter interface { + WaitN(ctx context.Context, storeID uint64, n int) error + Limit() int +} + 
+type storeWriteLimiter struct { + rwm sync.RWMutex + limiters map[uint64]*rate.Limiter + limit int + burst int +} + +func newStoreWriteLimiter(limit int) *storeWriteLimiter { + var burst int + // Allow a burst of up to 120% of the limit (20% above it). + if limit <= math.MaxInt-limit/5 { + burst = limit + limit/5 + } else { + // If overflowed, set burst to math.MaxInt. + burst = math.MaxInt + } + return &storeWriteLimiter{ + limiters: make(map[uint64]*rate.Limiter), + limit: limit, + burst: burst, + } +} + +func (s *storeWriteLimiter) WaitN(ctx context.Context, storeID uint64, n int) error { + limiter := s.getLimiter(storeID) + // The original WaitN doesn't allow n > burst, + // so we call WaitN with burst multiple times. + for n > limiter.Burst() { + if err := limiter.WaitN(ctx, limiter.Burst()); err != nil { + return err + } + n -= limiter.Burst() + } + return limiter.WaitN(ctx, n) +} + +func (s *storeWriteLimiter) Limit() int { + return s.limit +} + +func (s *storeWriteLimiter) getLimiter(storeID uint64) *rate.Limiter { + s.rwm.RLock() + limiter, ok := s.limiters[storeID] + s.rwm.RUnlock() + if ok { + return limiter + } + s.rwm.Lock() + defer s.rwm.Unlock() + limiter, ok = s.limiters[storeID] + if !ok { + limiter = rate.NewLimiter(rate.Limit(s.limit), s.burst) + s.limiters[storeID] = limiter + } + return limiter +} + +type noopStoreWriteLimiter struct{} + +func (noopStoreWriteLimiter) WaitN(ctx context.Context, storeID uint64, n int) error { + return nil +} + +func (noopStoreWriteLimiter) Limit() int { + return math.MaxInt +} diff --git a/br/pkg/lightning/backend/local/localhelper_test.go b/br/pkg/lightning/backend/local/localhelper_test.go index 48ce64da5e3b6..767829e9c857f 100644 --- a/br/pkg/lightning/backend/local/localhelper_test.go +++ b/br/pkg/lightning/backend/local/localhelper_test.go @@ -770,3 +770,46 @@ func TestNeedSplit(t *testing.T) { } } } + +func TestStoreWriteLimiter(t *testing.T) { + // Test create store write limiter with limit math.MaxInt. 
+ limiter := newStoreWriteLimiter(math.MaxInt) + err := limiter.WaitN(context.Background(), 1, 1024) + require.NoError(t, err) + + // Test WaitN exceeds the burst. + limiter = newStoreWriteLimiter(100) + start := time.Now() + // 120 is the initial burst, the other 120 are newly generated tokens. + err = limiter.WaitN(context.Background(), 1, 120+120) + require.NoError(t, err) + require.Greater(t, time.Since(start), time.Second) + + // Test WaitN with different store id. + limiter = newStoreWriteLimiter(100) + var wg sync.WaitGroup + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + for i := 0; i < 10; i++ { + wg.Add(1) + go func(storeID uint64) { + defer wg.Done() + start := time.Now() + var gotTokens int + for { + n := rand.Intn(50) + if limiter.WaitN(ctx, storeID, n) != nil { + break + } + gotTokens += n + } + elapsed := time.Since(start) + maxTokens := 120 + int(float64(elapsed)/float64(time.Second)*100) + // In theory, gotTokens should be less than or equal to maxTokens. + // But we allow a small margin of error to avoid the test being flaky. 
+ require.LessOrEqual(t, gotTokens, maxTokens+1) + + }(uint64(i)) + } + wg.Wait() +} diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go index fee2aaf29deb2..b0ffe32fa3cd5 100644 --- a/br/pkg/lightning/config/config.go +++ b/br/pkg/lightning/config/config.go @@ -532,6 +532,7 @@ type TikvImporter struct { EngineMemCacheSize ByteSize `toml:"engine-mem-cache-size" json:"engine-mem-cache-size"` LocalWriterMemCacheSize ByteSize `toml:"local-writer-mem-cache-size" json:"local-writer-mem-cache-size"` + StoreWriteBWLimit ByteSize `toml:"store-write-bwlimit" json:"store-write-bwlimit"` } type Checkpoint struct { diff --git a/br/tests/lightning_write_limit/config.toml b/br/tests/lightning_write_limit/config.toml new file mode 100644 index 0000000000000..e45e694126964 --- /dev/null +++ b/br/tests/lightning_write_limit/config.toml @@ -0,0 +1,5 @@ +[tikv-importer] +store-write-bwlimit = "1Mi" + +[mydumper.csv] +header = false diff --git a/br/tests/lightning_write_limit/run.sh b/br/tests/lightning_write_limit/run.sh new file mode 100644 index 0000000000000..b48d34e79a58d --- /dev/null +++ b/br/tests/lightning_write_limit/run.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# +# Copyright 2022 PingCAP, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set -eux + +mkdir -p "$TEST_DIR/data" + +cat <<EOF >"$TEST_DIR/data/test-schema-create.sql" +CREATE DATABASE test; +EOF +cat <<EOF >"$TEST_DIR/data/test.t-schema.sql" +CREATE TABLE test.t ( + id int, + a int, + b int, + c int +); +EOF + +# Generate 200k rows. Total size is about 5MiB. +set +x +for i in {1..200000}; do + echo "$i,$i,$i,$i" >>"$TEST_DIR/data/test.t.0.csv" +done +set -x + +start=$(date +%s) +run_lightning --backend local -d "$TEST_DIR/data" --config "tests/$TEST_NAME/config.toml" +end=$(date +%s) +take=$((end - start)) + +# The encoded kv size is 10MiB. Usually it should take more than 10s. +if [ $take -lt 10 ]; then + echo "Lightning runs too fast. The write limiter doesn't work." + exit 1 +fi diff --git a/br/tidb-lightning.toml b/br/tidb-lightning.toml index 8840eba06bb1d..a33eb46500104 100644 --- a/br/tidb-lightning.toml +++ b/br/tidb-lightning.toml @@ -136,6 +136,8 @@ addr = "127.0.0.1:8287" # The memory cache used in for local sorting during the encode-KV phase before flushing into the engines. The memory # usage is bound by region-concurrency * local-writer-mem-cache-size. #local-writer-mem-cache-size = '128MiB' +# Limit the write bandwidth to each tikv store. The unit is 'Bytes per second'. 0 means no limit. 
+#store-write-bwlimit = 0 [mydumper] # block size of file reading diff --git a/cmd/explaintest/r/collation_check_use_collation_disabled.result b/cmd/explaintest/r/collation_check_use_collation_disabled.result index 5b335fba0a59f..06af2890faa8f 100644 --- a/cmd/explaintest/r/collation_check_use_collation_disabled.result +++ b/cmd/explaintest/r/collation_check_use_collation_disabled.result @@ -128,4 +128,29 @@ select col_25 from tbl_2 where ( tbl_2.col_27 > 'nSWYrpTH' or not( tbl_2.col_27 col_25 select col_25 from tbl_2 use index(primary) where ( tbl_2.col_27 > 'nSWYrpTH' or not( tbl_2.col_27 between 'CsWIuxlSjU' and 'SfwoyjUEzgg' ) ) and ( tbl_2.col_23 <= -95); col_25 +drop table if exists t1; +drop table if exists t2; +create table t1(a char(20)); +create table t2(b binary(20), c binary(20)); +insert into t1 value('-1'); +insert into t2 value(0x2D31, 0x67); +insert into t2 value(0x2D31, 0x73); +select a from t1, t2 where t1.a between t2.b and t2.c; +a +select a from t1, t2 where cast(t1.a as binary(20)) between t2.b and t2.c; +a +-1 +-1 +drop table if exists t1; +drop table if exists t2; +create table t1(a char(20)) collate utf8mb4_general_ci; +create table t2(b binary(20), c char(20)) collate utf8mb4_general_ci; +insert into t1 values ('a'); +insert into t2 values (0x0, 'A'); +select * from t1, t2 where t1.a between t2.b and t2.c; +a b c +insert into t1 values ('-1'); +insert into t2 values (0x2d31, ''); +select * from t1, t2 where t1.a in (t2.b, 3); +a b c use test diff --git a/cmd/explaintest/r/collation_check_use_collation_enabled.result b/cmd/explaintest/r/collation_check_use_collation_enabled.result index 7aee634a8f0df..5bf70a6a73a09 100644 --- a/cmd/explaintest/r/collation_check_use_collation_enabled.result +++ b/cmd/explaintest/r/collation_check_use_collation_enabled.result @@ -147,4 +147,29 @@ col_25 select col_25 from tbl_2 use index(primary) where ( tbl_2.col_27 > 'nSWYrpTH' or not( tbl_2.col_27 between 'CsWIuxlSjU' and 'SfwoyjUEzgg' ) ) and ( 
tbl_2.col_23 <= -95); col_25 89 +drop table if exists t1; +drop table if exists t2; +create table t1(a char(20)); +create table t2(b binary(20), c binary(20)); +insert into t1 value('-1'); +insert into t2 value(0x2D31, 0x67); +insert into t2 value(0x2D31, 0x73); +select a from t1, t2 where t1.a between t2.b and t2.c; +a +select a from t1, t2 where cast(t1.a as binary(20)) between t2.b and t2.c; +a +-1 +-1 +drop table if exists t1; +drop table if exists t2; +create table t1(a char(20)) collate utf8mb4_general_ci; +create table t2(b binary(20), c char(20)) collate utf8mb4_general_ci; +insert into t1 values ('a'); +insert into t2 values (0x0, 'A'); +select * from t1, t2 where t1.a between t2.b and t2.c; +a b c +insert into t1 values ('-1'); +insert into t2 values (0x2d31, ''); +select * from t1, t2 where t1.a in (t2.b, 3); +a b c use test diff --git a/cmd/explaintest/t/collation_check_use_collation.test b/cmd/explaintest/t/collation_check_use_collation.test index 04d4642656de9..ebaa37588d153 100644 --- a/cmd/explaintest/t/collation_check_use_collation.test +++ b/cmd/explaintest/t/collation_check_use_collation.test @@ -86,5 +86,28 @@ insert ignore into tbl_2 values ( 5888267793391993829,5371,94.63,-109,5728076076 select col_25 from tbl_2 where ( tbl_2.col_27 > 'nSWYrpTH' or not( tbl_2.col_27 between 'CsWIuxlSjU' and 'SfwoyjUEzgg' ) ) and ( tbl_2.col_23 <= -95); select col_25 from tbl_2 use index(primary) where ( tbl_2.col_27 > 'nSWYrpTH' or not( tbl_2.col_27 between 'CsWIuxlSjU' and 'SfwoyjUEzgg' ) ) and ( tbl_2.col_23 <= -95); +# check implicit binary collation cast +drop table if exists t1; +drop table if exists t2; +# issue 34823 +create table t1(a char(20)); +create table t2(b binary(20), c binary(20)); +insert into t1 value('-1'); +insert into t2 value(0x2D31, 0x67); +insert into t2 value(0x2D31, 0x73); +select a from t1, t2 where t1.a between t2.b and t2.c; +select a from t1, t2 where cast(t1.a as binary(20)) between t2.b and t2.c; +# binary collation in single 
side +drop table if exists t1; +drop table if exists t2; +create table t1(a char(20)) collate utf8mb4_general_ci; +create table t2(b binary(20), c char(20)) collate utf8mb4_general_ci; +insert into t1 values ('a'); +insert into t2 values (0x0, 'A'); +select * from t1, t2 where t1.a between t2.b and t2.c; +insert into t1 values ('-1'); +insert into t2 values (0x2d31, ''); +select * from t1, t2 where t1.a in (t2.b, 3); + # cleanup environment use test diff --git a/errno/errcode.go b/errno/errcode.go index 0d80810516897..a26a1a1eaea6e 100644 --- a/errno/errcode.go +++ b/errno/errcode.go @@ -1023,6 +1023,8 @@ const ( ErrAssertionFailed = 8141 ErrInstanceScope = 8142 ErrNonTransactionalJobFailure = 8143 + ErrSettingNoopVariable = 8144 + ErrGettingNoopVariable = 8145 // Error codes used by TiDB ddl package ErrUnsupportedDDLOperation = 8200 diff --git a/errno/errname.go b/errno/errname.go index e14ddbe22ee2f..58866b7564cd0 100644 --- a/errno/errname.go +++ b/errno/errname.go @@ -1018,6 +1018,8 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{ ErrAssertionFailed: mysql.Message("assertion failed: key: %s, assertion: %s, start_ts: %v, existing start ts: %v, existing commit ts: %v", []int{0}), ErrInstanceScope: mysql.Message("modifying %s will require SET GLOBAL in a future version of TiDB", nil), ErrNonTransactionalJobFailure: mysql.Message("non-transactional job failed, job id: %d, total jobs: %d. 
job range: [%s, %s], job sql: %s, err: %v", []int{2, 3, 4}), + ErrSettingNoopVariable: mysql.Message("setting %s has no effect in TiDB", nil), + ErrGettingNoopVariable: mysql.Message("variable %s has no effect in TiDB", nil), ErrWarnOptimizerHintInvalidInteger: mysql.Message("integer value is out of range in '%s'", nil), ErrWarnOptimizerHintUnsupportedHint: mysql.Message("Optimizer hint %s is not supported by TiDB and is ignored", nil), diff --git a/errors.toml b/errors.toml index 6ccebe76a9b46..11518f664cda6 100755 --- a/errors.toml +++ b/errors.toml @@ -1456,6 +1456,11 @@ error = ''' modifying %s will require SET GLOBAL in a future version of TiDB ''' +["executor:8144"] +error = ''' +setting %s has no effect in TiDB +''' + ["executor:8212"] error = ''' Failed to split region ranges: %s @@ -2076,6 +2081,11 @@ error = ''' Column '%s' in ANALYZE column option does not exist in table '%s' ''' +["planner:8145"] +error = ''' +variable %s has no effect in TiDB +''' + ["planner:8242"] error = ''' '%s' is unsupported on cache tables. 
diff --git a/executor/errors.go b/executor/errors.go index 7551430e6901f..c65962f490f9c 100644 --- a/executor/errors.go +++ b/executor/errors.go @@ -55,6 +55,7 @@ var ( ErrInvalidSplitRegionRanges = dbterror.ClassExecutor.NewStd(mysql.ErrInvalidSplitRegionRanges) ErrViewInvalid = dbterror.ClassExecutor.NewStd(mysql.ErrViewInvalid) ErrInstanceScope = dbterror.ClassExecutor.NewStd(mysql.ErrInstanceScope) + ErrSettingNoopVariable = dbterror.ClassExecutor.NewStd(mysql.ErrSettingNoopVariable) ErrBRIEBackupFailed = dbterror.ClassExecutor.NewStd(mysql.ErrBRIEBackupFailed) ErrBRIERestoreFailed = dbterror.ClassExecutor.NewStd(mysql.ErrBRIERestoreFailed) diff --git a/executor/partition_table_test.go b/executor/partition_table_test.go index cccb827b0b030..3b8b50174ebc5 100644 --- a/executor/partition_table_test.go +++ b/executor/partition_table_test.go @@ -3468,8 +3468,8 @@ func TestPartitionTableExplain(t *testing.T) { "PartitionUnion 2.00 root ", "├─Batch_Point_Get 1.00 root table:t handle:[1 2], keep order:false, desc:false", "└─Batch_Point_Get 1.00 root table:t handle:[1 2], keep order:false, desc:false")) - tk.MustQuery(`explain format = 'brief' select * from t where a IN (2,3,4)`).Check(testkit.Rows("Batch_Point_Get 3.00 root table:t handle:[2 3 4], keep order:false, desc:false")) - tk.MustQuery(`explain format = 'brief' select * from t where a IN (2,3)`).Check(testkit.Rows("Batch_Point_Get 2.00 root table:t handle:[2 3], keep order:false, desc:false")) + tk.MustQuery(`explain format = 'brief' select * from t where a IN (2,3,4)`).Check(testkit.Rows("Batch_Point_Get 3.00 root table:t, partition:P0,p1,P2 handle:[2 3 4], keep order:false, desc:false")) + tk.MustQuery(`explain format = 'brief' select * from t where a IN (2,3)`).Check(testkit.Rows("Batch_Point_Get 2.00 root table:t, partition:P0,P2 handle:[2 3], keep order:false, desc:false")) // above ^^ is for completeness, the below vv is enough for Issue32719 tk.MustQuery(`explain format = 'brief' select * from t where b 
= 1`).Check(testkit.Rows( "PartitionUnion 1.00 root ", @@ -3553,8 +3553,8 @@ func TestPartitionTableExplain(t *testing.T) { tk.MustQuery(`explain format = 'brief' select * from t where a = 1 OR a = 2`).Check(testkit.Rows( "TableReader 2.00 root partition:p1,P2 data:TableRangeScan", "└─TableRangeScan 2.00 cop[tikv] table:t range:[1,1], [2,2], keep order:false")) - tk.MustQuery(`explain format = 'brief' select * from t where a IN (2,3,4)`).Check(testkit.Rows("Batch_Point_Get 3.00 root table:t handle:[2 3 4], keep order:false, desc:false")) - tk.MustQuery(`explain format = 'brief' select * from t where a IN (2,3)`).Check(testkit.Rows("Batch_Point_Get 2.00 root table:t handle:[2 3], keep order:false, desc:false")) + tk.MustQuery(`explain format = 'brief' select * from t where a IN (2,3,4)`).Check(testkit.Rows("Batch_Point_Get 3.00 root table:t, partition:P0,p1,P2 handle:[2 3 4], keep order:false, desc:false")) + tk.MustQuery(`explain format = 'brief' select * from t where a IN (2,3)`).Check(testkit.Rows("Batch_Point_Get 2.00 root table:t, partition:P0,P2 handle:[2 3], keep order:false, desc:false")) tk.MustQuery(`explain format = 'brief' select * from t where b = 1`).Check(testkit.Rows( "IndexReader 1.00 root partition:all index:IndexRangeScan", "└─IndexRangeScan 1.00 cop[tikv] table:t, index:b(b) range:[1,1], keep order:false")) diff --git a/executor/set.go b/executor/set.go index a79055abb5dbe..df0868e45f875 100644 --- a/executor/set.go +++ b/executor/set.go @@ -115,7 +115,12 @@ func (e *SetExecutor) setSysVariable(ctx context.Context, name string, v *expres } return variable.ErrUnknownSystemVar.GenWithStackByArgs(name) } - + if sysVar.IsNoop && !variable.EnableNoopVariables.Load() { + // The variable is a noop. For compatibility we allow it to still + // be changed, but we append a warning since users might be expecting + // something that's not going to happen. 
+ sessionVars.StmtCtx.AppendWarning(ErrSettingNoopVariable.GenWithStackByArgs(sysVar.Name)) + } if sysVar.HasInstanceScope() && !v.IsGlobal && sessionVars.EnableLegacyInstanceScope { // For backward compatibility we will change the v.IsGlobal to true, // and append a warning saying this will not be supported in future. diff --git a/executor/set_test.go b/executor/set_test.go index 8408af7ce75e4..eb171e872d8c4 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -736,6 +736,31 @@ func TestSetVar(t *testing.T) { tk.MustQuery("select @@tidb_cost_model_version").Check(testkit.Rows("2")) } +func TestGetSetNoopVars(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + + // By default you can get/set noop sysvars without issue. + tk.MustQuery("SELECT @@query_cache_type").Check(testkit.Rows("OFF")) + tk.MustQuery("SHOW VARIABLES LIKE 'query_cache_type'").Check(testkit.Rows("query_cache_type OFF")) + tk.MustExec("SET query_cache_type=2") + tk.MustQuery("SELECT @@query_cache_type").Check(testkit.Rows("DEMAND")) + // When tidb_enable_noop_variables is OFF, you can GET in @@ context + // and always SET. But you can't see in SHOW VARIABLES. + // Warnings are also returned. + tk.MustExec("SET GLOBAL tidb_enable_noop_variables = OFF") + defer tk.MustExec("SET GLOBAL tidb_enable_noop_variables = ON") + tk.MustQuery("SELECT @@global.tidb_enable_noop_variables").Check(testkit.Rows("OFF")) + tk.MustQuery("SELECT @@query_cache_type").Check(testkit.Rows("DEMAND")) + tk.MustQuery("SHOW WARNINGS").Check(testkit.Rows("Warning 8145 variable query_cache_type has no effect in TiDB")) + tk.MustQuery("SHOW VARIABLES LIKE 'query_cache_type'").Check(testkit.Rows()) + tk.MustExec("SET query_cache_type = OFF") + tk.MustQuery("SHOW WARNINGS").Check(testkit.Rows("Warning 8144 setting query_cache_type has no effect in TiDB")) + // but the change is still effective. 
+ tk.MustQuery("SELECT @@query_cache_type").Check(testkit.Rows("OFF")) +} + func TestTruncateIncorrectIntSessionVar(t *testing.T) { store, clean := testkit.CreateMockStore(t) defer clean() diff --git a/executor/show.go b/executor/show.go index 4b6d35d8ef187..9075444fd53f4 100644 --- a/executor/show.go +++ b/executor/show.go @@ -820,6 +820,9 @@ func (e *ShowExec) fetchShowVariables() (err error) { // otherwise, fetch the value from table `mysql.Global_Variables`. for _, v := range variable.GetSysVars() { if v.Scope != variable.ScopeSession { + if v.IsNoop && !variable.EnableNoopVariables.Load() { + continue + } if fieldFilter != "" && v.Name != fieldFilter { continue } else if fieldPatternsLike != nil && !fieldPatternsLike.DoMatch(v.Name) { @@ -842,6 +845,9 @@ func (e *ShowExec) fetchShowVariables() (err error) { // If it is a session only variable, use the default value defined in code, // otherwise, fetch the value from table `mysql.Global_Variables`. for _, v := range variable.GetSysVars() { + if v.IsNoop && !variable.EnableNoopVariables.Load() { + continue + } if fieldFilter != "" && v.Name != fieldFilter { continue } else if fieldPatternsLike != nil && !fieldPatternsLike.DoMatch(v.Name) { diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index d52b1aa8c3c20..c281e2de80302 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -1840,6 +1841,15 @@ func BuildCastCollationFunction(ctx sessionctx.Context, expr Expression, ec *Exp } else { return expr } + } else if ec.Charset == charset.CharsetBin { + // When cast character string to binary string, if we still use fixed length representation, + // then 0 padding will be used, which can affect later execution. + // e.g. 
https://github.com/pingcap/tidb/issues/34823. + // On the other hand, we can not directly return origin expr back, + // since we need binary collation to do string comparison later. + // e.g. https://github.com/pingcap/tidb/pull/35053#discussion_r894155052 + // Here we use VarString type of cast, i.e `cast(a as binary)`, to avoid this problem. + tp.SetType(mysql.TypeVarString) } tp.SetCharset(ec.Charset) tp.SetCollate(ec.Collation) diff --git a/go.mod b/go.mod index 78704e7a8ecd7..bfd038156bf93 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/Jeffail/gabs/v2 v2.5.1 github.com/Shopify/sarama v1.29.0 github.com/aws/aws-sdk-go v1.35.3 - github.com/blacktear23/go-proxyprotocol v0.0.0-20180807104634-af7a81e8dd0d + github.com/blacktear23/go-proxyprotocol v1.0.0 github.com/carlmjohnson/flagext v0.21.0 github.com/cheggaaa/pb/v3 v3.0.8 github.com/cheynewallace/tabby v1.1.1 diff --git a/go.sum b/go.sum index 66d46713ea35b..3a6e62f536235 100644 --- a/go.sum +++ b/go.sum @@ -119,8 +119,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= -github.com/blacktear23/go-proxyprotocol v0.0.0-20180807104634-af7a81e8dd0d h1:rQlvB2AYWme2bIB18r/SipGiMEVJYE9U0z+MGoU/LtQ= -github.com/blacktear23/go-proxyprotocol v0.0.0-20180807104634-af7a81e8dd0d/go.mod h1:VKt7CNAQxpFpSDz3sXyj9hY/GbVsQCr0sB3w59nE7lU= +github.com/blacktear23/go-proxyprotocol v1.0.0 h1:WmMmtZanGEfIHnJN9N3A4Pl6mM69D+GxEph2eOaCf7g= +github.com/blacktear23/go-proxyprotocol v1.0.0/go.mod h1:fbqiWSHMxaW0KsJ3SHjpxOMbTpIaQSMRn1GRd+oPyEw= github.com/carlmjohnson/flagext v0.21.0 h1:/c4uK3ie786Z7caXLcIMvePNSSiH3bQVGDvmGLMme60= github.com/carlmjohnson/flagext v0.21.0/go.mod 
h1:Eenv0epIUAr4NuedNmkzI8WmBmjIxZC239XcKxYS2ac= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= diff --git a/planner/core/errors.go b/planner/core/errors.go index 7182702e9d06a..84b92e39ef014 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -109,4 +109,5 @@ var ( ErrViewSelectTemporaryTable = dbterror.ClassOptimizer.NewStd(mysql.ErrViewSelectTmptable) ErrSubqueryMoreThan1Row = dbterror.ClassOptimizer.NewStd(mysql.ErrSubqueryNo1Row) ErrKeyPart0 = dbterror.ClassOptimizer.NewStd(mysql.ErrKeyPart0) + ErrGettingNoopVariable = dbterror.ClassOptimizer.NewStd(mysql.ErrGettingNoopVariable) ) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index d96a161afa40c..f80ea20ad07c9 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1297,6 +1297,10 @@ func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { } return } + if sysVar.IsNoop && !variable.EnableNoopVariables.Load() { + // The variable does nothing, append a warning to the statement output. 
+ sessionVars.StmtCtx.AppendWarning(ErrGettingNoopVariable.GenWithStackByArgs(sysVar.Name)) + } if sem.IsEnabled() && sem.IsInvisibleSysVar(sysVar.Name) { err := ErrSpecificAccessDenied.GenWithStackByArgs("RESTRICTED_VARIABLES_ADMIN") er.b.visitInfo = appendDynamicVisitInfo(er.b.visitInfo, "RESTRICTED_VARIABLES_ADMIN", false, err) @@ -1557,6 +1561,7 @@ func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen return } for i := stkLen - elemCnt; i < stkLen; i++ { + // todo: consider refining the code and reusing expression.BuildCollationFunction here if er.ctxStack[i].GetType().EvalType() == types.ETString { rowFunc, ok := er.ctxStack[i].(*expression.ScalarFunction) if ok && rowFunc.FuncName.String() == ast.RowFunc { @@ -1573,6 +1578,14 @@ func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen } else { continue } + } else if coll.Charset == charset.CharsetBin { + // When cast character string to binary string, if we still use fixed length representation, + // then 0 padding will be used, which can affect later execution. + // e.g. https://github.com/pingcap/tidb/pull/35053#pullrequestreview-1008757770 gives an unexpected case. + // On the other hand, we can not directly return origin expr back, + // since we need binary collation to do string comparison later. + // Here we use VarString type of cast, i.e `cast(a as binary)`, to avoid this problem. 
+ tp.SetType(mysql.TypeVarString) } tp.SetCharset(coll.Charset) tp.SetCollate(coll.Collation) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 87e9390ad1bde..c0890093b0080 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -6696,3 +6696,24 @@ func TestDecimalOverflow(t *testing.T) { tk.MustExec("insert into deci values (1234567890.123456789012345678901234567890,987654321098765432109876543210987654321098765432109876543210)") tk.MustQuery("select a from deci union ALL select b from deci;").Sort().Check(testkit.Rows("1234567890.123456789012345678901234567890", "99999999999999999999999999999999999.999999999999999999999999999999")) } + +func TestIssue35083(t *testing.T) { + defer func() { + variable.SetSysVar(variable.TiDBOptProjectionPushDown, variable.BoolToOnOff(config.GetGlobalConfig().Performance.ProjectionPushDown)) + }() + defer config.RestoreFunc()() + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.ProjectionPushDown = true + }) + variable.SetSysVar(variable.TiDBOptProjectionPushDown, variable.BoolToOnOff(config.GetGlobalConfig().Performance.ProjectionPushDown)) + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t1 (a varchar(100), b int)") + tk.MustQuery("select @@tidb_opt_projection_push_down").Check(testkit.Rows("1")) + tk.MustQuery("explain format = 'brief' select cast(a as datetime) from t1").Check(testkit.Rows( + "TableReader 10000.00 root data:Projection", + "└─Projection 10000.00 cop[tikv] cast(test.t1.a, datetime BINARY)->Column#4", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo")) +} diff --git a/planner/core/partition_pruner_test.go b/planner/core/partition_pruner_test.go index 2e7d704c8663e..387396f25235c 100644 --- a/planner/core/partition_pruner_test.go +++ b/planner/core/partition_pruner_test.go @@ -704,6 +704,22 @@ func 
TestHashPartitionPruning(t *testing.T) { tk.MustQuery("SELECT col1, COL3 FROM t WHERE COL1 IN (0,14158354938390,0) AND COL3 IN (3522101843073676459,-2846203247576845955,838395691793635638);").Check(testkit.Rows("0 3522101843073676459")) } +func TestIssue32815(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@tidb_partition_prune_mode='dynamic'") + tk.MustExec("USE test;") + tk.MustExec("DROP TABLE IF EXISTS t;") + tk.MustExec("create table t (a int primary key, b int, key (b)) partition by hash(a) (partition P0, partition p1, partition P2)") + tk.MustExec("insert into t values (1, 1),(2, 2),(3, 3)") + tk.MustQuery("explain select * from t where a IN (1, 2)").Check(testkit.Rows( + "Batch_Point_Get_1 2.00 root table:t, partition:p1,P2 handle:[1 2], keep order:false, desc:false")) + tk.MustQuery("explain select * from t where a IN (1, 2, 1)").Check(testkit.Rows( + "Batch_Point_Get_1 3.00 root table:t, partition:p1,P2 handle:[1 2 1], keep order:false, desc:false")) +} + func TestIssue32007(t *testing.T) { store, clean := testkit.CreateMockStore(t) defer clean() diff --git a/planner/core/plan_cost.go b/planner/core/plan_cost.go index 6e194965901cf..ee0af71c51149 100644 --- a/planner/core/plan_cost.go +++ b/planner/core/plan_cost.go @@ -312,7 +312,12 @@ func (p *PhysicalTableReader) GetPlanCost(taskType property.TaskType, costFlag u concurrency = float64(p.ctx.GetSessionVars().DistSQLScanConcurrency()) rowSize = getTblStats(p.tablePlan).GetAvgRowSize(p.ctx, p.tablePlan.Schema().Columns, false, false) seekCost = estimateNetSeekCost(p.tablePlan) - childCost, err := p.tablePlan.GetPlanCost(property.CopSingleReadTaskType, costFlag) + tType := property.MppTaskType + if p.ctx.GetSessionVars().CostModelVersion == modelVer1 { + // regard the underlying tasks as cop-task on modelVer1 for compatibility + tType = property.CopSingleReadTaskType + } + childCost, err := 
p.tablePlan.GetPlanCost(tType, costFlag) if err != nil { return 0, err } @@ -326,7 +331,8 @@ func (p *PhysicalTableReader) GetPlanCost(taskType property.TaskType, costFlag u // consider concurrency p.planCost /= concurrency // consider tidb_enforce_mpp - if isMPP && p.ctx.GetSessionVars().IsMPPEnforced() { + if isMPP && p.ctx.GetSessionVars().IsMPPEnforced() && + !hasCostFlag(costFlag, CostFlagRecalculate) { // show the real cost in explain-statements p.planCost /= 1000000000 } } @@ -892,12 +898,19 @@ func (p *PhysicalHashJoin) GetPlanCost(taskType property.TaskType, costFlag uint } // GetCost computes cost of stream aggregation considering CPU/memory. -func (p *PhysicalStreamAgg) GetCost(inputRows float64, isRoot bool, costFlag uint64) float64 { +func (p *PhysicalStreamAgg) GetCost(inputRows float64, isRoot, isMPP bool, costFlag uint64) float64 { aggFuncFactor := p.getAggFuncCostFactor(false) var cpuCost float64 sessVars := p.ctx.GetSessionVars() if isRoot { cpuCost = inputRows * sessVars.GetCPUFactor() * aggFuncFactor + } else if isMPP { + if p.ctx.GetSessionVars().CostModelVersion == modelVer2 { + // use the dedicated CPU factor for TiFlash on modelVer2 + cpuCost = inputRows * sessVars.GetTiFlashCPUFactor() * aggFuncFactor + } else { + cpuCost = inputRows * sessVars.GetCopCPUFactor() * aggFuncFactor + } } else { cpuCost = inputRows * sessVars.GetCopCPUFactor() * aggFuncFactor } @@ -916,7 +929,7 @@ func (p *PhysicalStreamAgg) GetPlanCost(taskType property.TaskType, costFlag uin return 0, err } p.planCost = childCost - p.planCost += p.GetCost(getCardinality(p.children[0], costFlag), taskType == property.RootTaskType, costFlag) + p.planCost += p.GetCost(getCardinality(p.children[0], costFlag), taskType == property.RootTaskType, taskType == property.MppTaskType, costFlag) p.planCostInit = true return p.planCost, nil } @@ -936,6 +949,13 @@ func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot, isMPP bool, costFla // Cost of additional goroutines. 
cpuCost += (con + 1) * sessVars.GetConcurrencyFactor() } + } else if isMPP { + if p.ctx.GetSessionVars().CostModelVersion == modelVer2 { + // use the dedicated CPU factor for TiFlash on modelVer2 + cpuCost = inputRows * sessVars.GetTiFlashCPUFactor() * aggFuncFactor + } else { + cpuCost = inputRows * sessVars.GetCopCPUFactor() * aggFuncFactor + } } else { cpuCost = inputRows * sessVars.GetCopCPUFactor() * aggFuncFactor } @@ -1144,6 +1164,9 @@ func (p *PhysicalExchangeReceiver) GetPlanCost(taskType property.TaskType, costF } func getOperatorActRows(operator PhysicalPlan) float64 { + if operator == nil { + return 0 + } runtimeInfo := operator.SCtx().GetSessionVars().StmtCtx.RuntimeStatsColl id := operator.ID() actRows := 0.0 diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index afb34814f7196..01b6337ebc81f 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -270,6 +270,7 @@ type BatchPointGetPlan struct { dbName string TblInfo *model.TableInfo IndexInfo *model.IndexInfo + PartitionInfos []*model.PartitionDefinition Handles []kv.Handle HandleType *types.FieldType HandleParams []*expression.Constant // record all Parameters for Plan-Cache @@ -345,11 +346,25 @@ func (p *BatchPointGetPlan) ExplainNormalizedInfo() string { } // AccessObject implements physicalScan interface. 
-func (p *BatchPointGetPlan) AccessObject(_ bool) string { +func (p *BatchPointGetPlan) AccessObject(normalized bool) string { var buffer strings.Builder tblName := p.TblInfo.Name.O buffer.WriteString("table:") buffer.WriteString(tblName) + if p.PartitionInfos != nil { + if normalized { + buffer.WriteString(", partition:?") + } else { + for i, partitionInfo := range p.PartitionInfos { + if i == 0 { + buffer.WriteString(", partition:") + } else { + buffer.WriteString(",") + } + buffer.WriteString(partitionInfo.Name.O) + } + } + } if p.IndexInfo != nil { if p.IndexInfo.Primary && p.TblInfo.IsCommonHandle { buffer.WriteString(", clustered index:" + p.IndexInfo.Name.O + "(") @@ -565,9 +580,13 @@ func newBatchPointGetPlan( return nil } } + if handleCol != nil { + // condition key of where is primary key var handles = make([]kv.Handle, len(patternInExpr.List)) var handleParams = make([]*expression.Constant, len(patternInExpr.List)) + var pos2PartitionDefinition = make(map[int]*model.PartitionDefinition) + partitionInfos := make([]*model.PartitionDefinition, 0, len(patternInExpr.List)) for i, item := range patternInExpr.List { // SELECT * FROM t WHERE (key) in ((1), (2)) if p, ok := item.(*ast.ParenthesesExpr); ok { @@ -600,13 +619,39 @@ func newBatchPointGetPlan( } handles[i] = kv.IntHandle(intDatum.GetInt64()) handleParams[i] = con + pairs := []nameValuePair{{colName: handleCol.Name.L, colFieldType: item.GetType(), value: *intDatum, con: con}} + if tbl.GetPartitionInfo() != nil { + tmpPartitionDefinition, _, pos, isTableDual := getPartitionInfo(ctx, tbl, pairs) + if isTableDual { + return nil + } + if tmpPartitionDefinition != nil { + pos2PartitionDefinition[pos] = tmpPartitionDefinition + } + } + } + + posArr := make([]int, len(pos2PartitionDefinition)) + i := 0 + for pos := range pos2PartitionDefinition { + posArr[i] = pos + i++ + } + sort.Ints(posArr) + for _, pos := range posArr { + partitionInfos = append(partitionInfos, pos2PartitionDefinition[pos]) } + if 
len(partitionInfos) == 0 { + partitionInfos = nil + } + return BatchPointGetPlan{ - TblInfo: tbl, - Handles: handles, - HandleParams: handleParams, - HandleType: &handleCol.FieldType, - PartitionExpr: partitionExpr, + TblInfo: tbl, + Handles: handles, + HandleParams: handleParams, + HandleType: &handleCol.FieldType, + PartitionExpr: partitionExpr, + PartitionInfos: partitionInfos, }.Init(ctx, statsInfo, schema, names, 0) } @@ -661,14 +706,18 @@ func newBatchPointGetPlan( indexValues := make([][]types.Datum, len(patternInExpr.List)) indexValueParams := make([][]*expression.Constant, len(patternInExpr.List)) + partitionInfos := make([]*model.PartitionDefinition, 0, len(patternInExpr.List)) + var pos2PartitionDefinition = make(map[int]*model.PartitionDefinition) + var indexTypes []*types.FieldType for i, item := range patternInExpr.List { - // SELECT * FROM t WHERE (key) in ((1), (2)) + // SELECT * FROM t WHERE (key) in ((1), (2)) or SELECT * FROM t WHERE (key1, key2) in ((1, 1), (2, 2)) if p, ok := item.(*ast.ParenthesesExpr); ok { item = p.Expr } var values []types.Datum var valuesParams []*expression.Constant + var pairs []nameValuePair switch x := item.(type) { case *ast.RowExpr: // The `len(values) == len(valuesParams)` should be satisfied in this mode @@ -676,6 +725,7 @@ func newBatchPointGetPlan( return nil } values = make([]types.Datum, len(x.Values)) + pairs = make([]nameValuePair, 0, len(x.Values)) valuesParams = make([]*expression.Constant, len(x.Values)) initTypes := false if indexTypes == nil { // only init once @@ -683,6 +733,7 @@ func newBatchPointGetPlan( initTypes = true } for index, inner := range x.Values { + // permutations is used to match column and value. 
permIndex := permutations[index] switch innerX := inner.(type) { case *driver.ValueExpr: @@ -691,6 +742,7 @@ func newBatchPointGetPlan( return nil } values[permIndex] = innerX.Datum + pairs = append(pairs, nameValuePair{colName: whereColNames[index], value: innerX.Datum}) case *driver.ParamMarkerExpr: con, err := expression.ParamMarkerExpression(ctx, innerX, true) if err != nil { @@ -709,6 +761,7 @@ func newBatchPointGetPlan( if initTypes { indexTypes[permIndex] = &colInfos[index].FieldType } + pairs = append(pairs, nameValuePair{colName: whereColNames[index], value: innerX.Datum}) default: return nil } @@ -724,6 +777,8 @@ func newBatchPointGetPlan( return nil } values = []types.Datum{*dval} + valuesParams = []*expression.Constant{nil} + pairs = append(pairs, nameValuePair{colName: whereColNames[0], value: *dval}) case *driver.ParamMarkerExpr: if len(whereColNames) != 1 { return nil @@ -745,12 +800,39 @@ func newBatchPointGetPlan( if indexTypes == nil { // only init once indexTypes = []*types.FieldType{&colInfos[0].FieldType} } + pairs = append(pairs, nameValuePair{colName: whereColNames[0], value: *dval}) + default: return nil } indexValues[i] = values indexValueParams[i] = valuesParams + if tbl.GetPartitionInfo() != nil { + tmpPartitionDefinition, _, pos, isTableDual := getPartitionInfo(ctx, tbl, pairs) + if isTableDual { + return nil + } + if tmpPartitionDefinition != nil { + pos2PartitionDefinition[pos] = tmpPartitionDefinition + } + } + + } + + posArr := make([]int, len(pos2PartitionDefinition)) + i := 0 + for pos := range pos2PartitionDefinition { + posArr[i] = pos + i++ } + sort.Ints(posArr) + for _, pos := range posArr { + partitionInfos = append(partitionInfos, pos2PartitionDefinition[pos]) + } + if len(partitionInfos) == 0 { + partitionInfos = nil + } + return BatchPointGetPlan{ TblInfo: tbl, IndexInfo: matchIdxInfo, @@ -759,6 +841,7 @@ func newBatchPointGetPlan( IndexColTypes: indexTypes, PartitionColPos: pos, PartitionExpr: partitionExpr, + 
PartitionInfos: partitionInfos, }.Init(ctx, statsInfo, schema, names, 0) } @@ -768,6 +851,8 @@ func tryWhereIn2BatchPointGet(ctx sessionctx.Context, selStmt *ast.SelectStmt) * len(selStmt.WindowSpecs) > 0 { return nil } + // `expr1 in (1, 2) and expr2 in (1, 2)` isn't PatternInExpr, so it can't use tryWhereIn2BatchPointGet. + // (expr1, expr2) in ((1, 1), (2, 2)) can hit it. in, ok := selStmt.Where.(*ast.PatternInExpr) if !ok || in.Not || len(in.List) < 1 { return nil @@ -907,7 +992,7 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt, check bool var partitionInfo *model.PartitionDefinition var pos int if pi != nil { - partitionInfo, pos, isTableDual = getPartitionInfo(ctx, tbl, pairs) + partitionInfo, pos, _, isTableDual = getPartitionInfo(ctx, tbl, pairs) if isTableDual { p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl, names) p.IsTableDual = true @@ -1583,15 +1668,15 @@ func buildHandleCols(ctx sessionctx.Context, tbl *model.TableInfo, schema *expre return &IntHandleCols{col: handleCol} } -func getPartitionInfo(ctx sessionctx.Context, tbl *model.TableInfo, pairs []nameValuePair) (*model.PartitionDefinition, int, bool) { +func getPartitionInfo(ctx sessionctx.Context, tbl *model.TableInfo, pairs []nameValuePair) (*model.PartitionDefinition, int, int, bool) { partitionExpr := getPartitionExpr(ctx, tbl) if partitionExpr == nil { - return nil, 0, false + return nil, 0, 0, false } pi := tbl.GetPartitionInfo() if pi == nil { - return nil, 0, false + return nil, 0, 0, false } switch pi.Type { @@ -1599,19 +1684,19 @@ func getPartitionInfo(ctx sessionctx.Context, tbl *model.TableInfo, pairs []name expr := partitionExpr.OrigExpr col, ok := expr.(*ast.ColumnNameExpr) if !ok { - return nil, 0, false + return nil, 0, 0, false } partitionColName := col.Name if partitionColName == nil { - return nil, 0, false + return nil, 0, 0, false } for i, pair := range pairs { if partitionColName.Name.L == pair.colName { val := pair.value.GetInt64() pos := 
mathutil.Abs(val % int64(pi.Num)) - return &pi.Definitions[pos], i, false + return &pi.Definitions[pos], i, int(pos), false } } case model.PartitionTypeRange: @@ -1629,9 +1714,9 @@ func getPartitionInfo(ctx sessionctx.Context, tbl *model.TableInfo, pairs []name return ranges.Compare(i, val, unsigned) > 0 }) if pos >= 0 && pos < length { - return &pi.Definitions[pos], i, false + return &pi.Definitions[pos], i, pos, false } - return nil, 0, true + return nil, 0, 0, true } } } @@ -1648,15 +1733,15 @@ func getPartitionInfo(ctx sessionctx.Context, tbl *model.TableInfo, pairs []name isNull := false pos := partitionExpr.ForListPruning.LocatePartition(val, isNull) if pos >= 0 { - return &pi.Definitions[pos], i, false + return &pi.Definitions[pos], i, pos, false } - return nil, 0, true + return nil, 0, 0, true } } } } } - return nil, 0, false + return nil, 0, 0, false } func findPartitionIdx(idxInfo *model.IndexInfo, pos int, pairs []nameValuePair) int { diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go index d39e95b767ab8..508e21004476d 100644 --- a/planner/core/point_get_plan_test.go +++ b/planner/core/point_get_plan_test.go @@ -662,18 +662,18 @@ func TestBatchPointGetPartition(t *testing.T) { tk.MustExec("create table t(a int primary key, b int) PARTITION BY HASH(a) PARTITIONS 4") tk.MustExec("insert into t values (1, 1), (2, 2), (3, 3), (4, 4)") tk.MustQuery("explain format = 'brief' select * from t where a in (1, 2, 3, 4)").Check(testkit.Rows( - "Batch_Point_Get 4.00 root table:t handle:[1 2 3 4], keep order:false, desc:false", + "Batch_Point_Get 4.00 root table:t, partition:p0,p1,p2,p3 handle:[1 2 3 4], keep order:false, desc:false", )) tk.MustQuery("select * from t where a in (1, 2, 3, 4)").Check(testkit.Rows("1 1", "2 2", "3 3", "4 4")) tk.MustQuery("explain format = 'brief' update t set b = b + 1 where a in (1, 2, 3, 4)").Check(testkit.Rows( - "Update N/A root N/A]\n[└─Batch_Point_Get 4.00 root table:t handle:[1 2 3 4], keep 
order:false, desc:false", + "Update N/A root N/A]\n[└─Batch_Point_Get 4.00 root table:t, partition:p0,p1,p2,p3 handle:[1 2 3 4], keep order:false, desc:false", )) tk.MustExec("update t set b = b + 1 where a in (1, 2, 3, 4)") tk.MustQuery("select * from t where a in (1, 2, 3, 4)").Check(testkit.Rows("1 2", "2 3", "3 4", "4 5")) tk.MustQuery("explain format = 'brief' delete from t where a in (1, 2, 3, 4)").Check(testkit.Rows( - "Delete N/A root N/A]\n[└─Batch_Point_Get 4.00 root table:t handle:[1 2 3 4], keep order:false, desc:false", + "Delete N/A root N/A]\n[└─Batch_Point_Get 4.00 root table:t, partition:p0,p1,p2,p3 handle:[1 2 3 4], keep order:false, desc:false", )) tk.MustExec("delete from t where a in (1, 2, 3, 4)") tk.MustQuery("select * from t where a in (1, 2, 3, 4)").Check(testkit.Rows()) @@ -681,26 +681,88 @@ func TestBatchPointGetPartition(t *testing.T) { tk.MustExec("drop table t") tk.MustExec("create table t(a int, b int, c int, primary key (a, b)) PARTITION BY HASH(a) PARTITIONS 4") tk.MustExec("insert into t values (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)") + tk.MustQuery("explain format = 'brief' select * from t where a = 1 and b = 1").Check(testkit.Rows("Point_Get 1.00 root table:t, partition:p1, clustered index:PRIMARY(a, b) ")) + tk.MustQuery("explain format = 'brief' select * from t where (a, b) in ((1, 1), (2, 2), (3, 3), (4, 4))").Check(testkit.Rows( - "Batch_Point_Get 4.00 root table:t, clustered index:PRIMARY(a, b) keep order:false, desc:false", + "Batch_Point_Get 4.00 root table:t, partition:p0,p1,p2,p3, clustered index:PRIMARY(a, b) keep order:false, desc:false", )) tk.MustQuery("select * from t where (a, b) in ((1, 1), (2, 2), (3, 3), (4, 4))"). 
Check(testkit.Rows("1 1 1", "2 2 2", "3 3 3", "4 4 4")) tk.MustQuery("explain format = 'brief' update t set c = c + 1 where (a,b) in ((1,1),(2,2),(3,3),(4,4))").Check(testkit.Rows( - "Update N/A root N/A]\n[└─Batch_Point_Get 4.00 root table:t, clustered index:PRIMARY(a, b) keep order:false, desc:false", + "Update N/A root N/A]\n[└─Batch_Point_Get 4.00 root table:t, partition:p0,p1,p2,p3, clustered index:PRIMARY(a, b) keep order:false, desc:false", )) tk.MustExec("update t set c = c + 1 where (a,b) in ((1,1),(2,2),(3,3),(4,4))") tk.MustQuery("select * from t where (a, b) in ((1, 1), (2, 2), (3, 3), (4, 4))").Sort(). Check(testkit.Rows("1 1 2", "2 2 3", "3 3 4", "4 4 5")) tk.MustQuery("explain format = 'brief' delete from t where (a,b) in ((1,1),(2,2),(3,3),(4,4))").Check(testkit.Rows( - "Delete N/A root N/A]\n[└─Batch_Point_Get 4.00 root table:t, clustered index:PRIMARY(a, b) keep order:false, desc:false", + "Delete N/A root N/A]\n[└─Batch_Point_Get 4.00 root table:t, partition:p0,p1,p2,p3, clustered index:PRIMARY(a, b) keep order:false, desc:false", )) tk.MustExec("delete from t where (a,b) in ((1,1),(2,2),(3,3),(4,4))") tk.MustQuery("select * from t where (a, b) in ((1, 1), (2, 2), (3, 3), (4, 4))").Check(testkit.Rows()) } +func TestBatchPointGetPartitionForAccessObject(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@tidb_partition_prune_mode = 'dynamic'") + tk.MustExec("use test") + tk.Session().GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, UNIQUE KEY (b)) PARTITION BY HASH(b) PARTITIONS 4") + tk.MustExec("insert into t values(1, 1), (2, 2), (3, 3), (4, 4)") + tk.MustQuery("explain select * from t where b in (1, 2)").Check(testkit.Rows( + "Batch_Point_Get_1 2.00 root table:t, partition:p1,p2, index:b(b) keep order:false, desc:false")) + tk.MustQuery("explain select * 
from t where b in (1, 2, 1)").Check(testkit.Rows( + "Batch_Point_Get_1 3.00 root table:t, partition:p1,p2, index:b(b) keep order:false, desc:false")) + + tk.MustExec("set @@session.tidb_enable_list_partition = ON") + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE t (id int primary key, name_id int) PARTITION BY LIST(id) (" + + "partition p0 values IN (1, 2), " + + "partition p1 values IN (3, 4), " + + "partition p3 values IN (5))") + tk.MustExec("insert into t values(1, 1), (2, 2), (3, 3), (4, 4)") + tk.MustQuery("explain format='brief' select * from t where id in (1, 3)").Check(testkit.Rows( + "Batch_Point_Get 2.00 root table:t, partition:p0,p1 handle:[1 3], keep order:false, desc:false")) + + tk.MustExec("set @@session.tidb_enable_list_partition = ON") + tk.MustExec("drop table if exists t0") + tk.MustExec("CREATE TABLE t0 (id int primary key, name_id int) PARTITION BY LIST COLUMNS(id) (" + + "partition p0 values IN (1, 2), " + + "partition p1 values IN (3, 4), " + + "partition p3 values IN (5))") + tk.MustExec("insert into t0 values(1, 1), (2, 2), (3, 3), (4, 4)") + tk.MustQuery("explain format='brief' select * from t0 where id in (1, 3)").Check(testkit.Rows( + "TableReader 2.00 root partition:p0,p1 data:TableRangeScan]\n" + + "[└─TableRangeScan 2.00 cop[tikv] table:t0 range:[1,1], [3,3], keep order:false, stats:pseudo")) + + tk.MustExec("set @@session.tidb_enable_list_partition = ON") + tk.MustExec("drop table if exists t1") + tk.MustExec("CREATE TABLE t1 (id int, name_id int, unique key(id, name_id)) PARTITION BY LIST COLUMNS(id, name_id) (" + + "partition p0 values IN ((1, 1),(2, 2)), " + + "partition p1 values IN ((3, 3),(4, 4)), " + + "partition p3 values IN ((5, 5)))") + tk.MustExec("insert into t1 values(1, 1), (2, 2), (3, 3), (4, 4)") + tk.MustQuery("explain format='brief' select * from t1 where (id, name_id) in ((1, 1), (3, 3))").Check(testkit.Rows( + "IndexReader 2.00 root partition:p0,p1 index:IndexRangeScan]\n" + + 
"[└─IndexRangeScan 2.00 cop[tikv] table:t1, index:id(id, name_id) range:[1 1,1 1], [3 3,3 3], keep order:false, stats:pseudo")) + + tk.MustExec("set @@session.tidb_enable_list_partition = ON") + tk.MustExec("drop table if exists t2") + tk.MustExec("CREATE TABLE t2 (id int, name varchar(10), unique key(id, name)) PARTITION BY LIST COLUMNS(id, name) (" + + "partition p0 values IN ((1,'a'),(2,'b')), " + + "partition p1 values IN ((3,'c'),(4,'d')), " + + "partition p3 values IN ((5,'e')))") + tk.MustExec("insert into t2 values(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')") + tk.MustQuery("explain format='brief' select * from t2 where (id, name) in ((1, 'a'), (3, 'c'))").Check(testkit.Rows( + "IndexReader 2.00 root partition:p0,p1 index:IndexRangeScan]\n" + + "[└─IndexRangeScan 2.00 cop[tikv] table:t2, index:id(id, name) range:[1 \"a\",1 \"a\"], [3 \"c\",3 \"c\"], keep order:false, stats:pseudo")) +} + func TestIssue19141(t *testing.T) { // For issue 19141, fix partition selection on batch point get. 
store, clean := testkit.CreateMockStore(t) diff --git a/planner/core/task.go b/planner/core/task.go index f197317eb836e..fd6cac675f4c6 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1665,7 +1665,7 @@ func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task { partialAgg.SetChildren(cop.indexPlan) cop.indexPlan = partialAgg } - cop.addCost(partialAgg.(*PhysicalStreamAgg).GetCost(inputRows, false, 0)) + cop.addCost(partialAgg.(*PhysicalStreamAgg).GetCost(inputRows, false, false, 0)) partialAgg.SetCost(cop.cost()) } t = cop.convertToRootTask(p.ctx) @@ -1678,7 +1678,7 @@ func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task { } else { attachPlan2Task(p, t) } - t.addCost(final.GetCost(inputRows, true, 0)) + t.addCost(final.GetCost(inputRows, true, false, 0)) t.plan().SetCost(t.cost()) return t } diff --git a/session/bootstrap.go b/session/bootstrap.go index 960ca9fa3a625..68e97c84d1dd2 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -1971,6 +1971,12 @@ func doDDLWorks(s Session) { mustExecute(s, CreateAdvisoryLocks) } +// inTestSuite checks if we are bootstrapping in the context of tests. +// There are some historical differences in behavior between tests and non-tests. +func inTestSuite() bool { + return flag.Lookup("test.v") != nil || flag.Lookup("check.v") != nil +} + // doDMLWorks executes DML statements in bootstrap stage. // All the statements run in a single transaction. // TODO: sanitize. @@ -1991,58 +1997,55 @@ func doDMLWorks(s Session) { ("%", "root", "", "mysql_native_password", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y", "Y")`) } - // Init global system variables table. + // For GLOBAL scoped system variables, insert the initial value + // into the mysql.global_variables table. 
This is only run on initial + // bootstrap, and in some cases we will use a different default value + // for new installs versus existing installs. + values := make([]string, 0, len(variable.GetSysVars())) for k, v := range variable.GetSysVars() { - // Only global variables should be inserted. - if v.HasGlobalScope() { - vVal := v.Value - if v.Name == variable.TiDBTxnMode && config.GetGlobalConfig().Store == "tikv" { + if !v.HasGlobalScope() { + continue + } + vVal := v.Value + switch v.Name { + case variable.TiDBTxnMode: + if config.GetGlobalConfig().Store == "tikv" { vVal = "pessimistic" } - if v.Name == variable.TiDBRowFormatVersion { - vVal = strconv.Itoa(variable.DefTiDBRowFormatV2) - } - if v.Name == variable.TiDBPartitionPruneMode { - vVal = variable.DefTiDBPartitionPruneMode - if flag.Lookup("test.v") != nil || flag.Lookup("check.v") != nil || config.CheckTableBeforeDrop { - // enable Dynamic Prune by default in test case. - vVal = string(variable.Dynamic) - } - } - if v.Name == variable.TiDBMemOOMAction { - if flag.Lookup("test.v") != nil || flag.Lookup("check.v") != nil { - // Change the OOM action to log for the test suite. - vVal = variable.OOMActionLog - } - } - if v.Name == variable.TiDBEnableChangeMultiSchema { - vVal = variable.Off - if flag.Lookup("test.v") != nil || flag.Lookup("check.v") != nil { - // enable change multi schema in test case for compatibility with old cases. 
- vVal = variable.On - } - } - if v.Name == variable.TiDBEnableAsyncCommit && config.GetGlobalConfig().Store == "tikv" { + case variable.TiDBEnableAsyncCommit, variable.TiDBEnable1PC: + if config.GetGlobalConfig().Store == "tikv" { vVal = variable.On } - if v.Name == variable.TiDBEnable1PC && config.GetGlobalConfig().Store == "tikv" { - vVal = variable.On + case variable.TiDBPartitionPruneMode: + if inTestSuite() || config.CheckTableBeforeDrop { + vVal = string(variable.Dynamic) } - if v.Name == variable.TiDBEnableMutationChecker { + case variable.TiDBEnableChangeMultiSchema: + if inTestSuite() { vVal = variable.On } - if v.Name == variable.TiDBEnableAutoAnalyze { - if flag.Lookup("test.v") != nil || flag.Lookup("check.v") != nil { - vVal = variable.Off - } + case variable.TiDBMemOOMAction: + if inTestSuite() { + vVal = variable.OOMActionLog } - if v.Name == variable.TiDBTxnAssertionLevel { - vVal = variable.AssertionFastStr + case variable.TiDBEnableAutoAnalyze: + if inTestSuite() { + vVal = variable.Off } - value := fmt.Sprintf(`("%s", "%s")`, strings.ToLower(k), vVal) - values = append(values, value) + // For the following sysvars, we change the default + // FOR NEW INSTALLS ONLY. In most cases you don't want to do this. + // It is better to change the value in the Sysvar struct, so that + // all installs will have the same value. 
+ case variable.TiDBRowFormatVersion: + vVal = strconv.Itoa(variable.DefTiDBRowFormatV2) + case variable.TiDBTxnAssertionLevel: + vVal = variable.AssertionFastStr + case variable.TiDBEnableMutationChecker: + vVal = variable.On } + value := fmt.Sprintf(`("%s", "%s")`, strings.ToLower(k), vVal) + values = append(values, value) } sql := fmt.Sprintf("INSERT HIGH_PRIORITY INTO %s.%s VALUES %s;", mysql.SystemDB, mysql.GlobalVariablesTable, strings.Join(values, ", ")) diff --git a/session/session.go b/session/session.go index bb1f15836fd90..3f163146e8796 100644 --- a/session/session.go +++ b/session/session.go @@ -56,6 +56,7 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/temptable" "github.com/pingcap/tidb/util/logutil/consistency" + "github.com/pingcap/tidb/util/sem" "github.com/pingcap/tidb/util/topsql" topsqlstate "github.com/pingcap/tidb/util/topsql/state" "github.com/pingcap/tidb/util/topsql/stmtstats" @@ -3496,10 +3497,45 @@ func (s *session) GetStmtStats() *stmtstats.StatementStats { // EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface. func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) { - return s.sessionVars.EncodeSessionStates(ctx, sessionStates) + if err = s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil { + return err + } + + // Encode session variables. We put it here instead of SessionVars to avoid cycle import. + sessionStates.SystemVars = make(map[string]string) + for _, sv := range variable.GetSysVars() { + switch { + case sv.Hidden, sv.HasNoneScope(), sv.HasInstanceScope(), !sv.HasSessionScope(): + // Hidden and none-scoped variables cannot be modified. + // Instance-scoped variables don't need to be encoded. + // Noop variables should also be migrated even if they are noop. + continue + case sv.ReadOnly: + // Skip read-only variables here. We encode them into SessionStates manually. 
+ continue + case sem.IsEnabled() && sem.IsInvisibleSysVar(sv.Name): + // If they are shown, there will be a security issue. + continue + } + // Get all session variables because the default values may change between versions. + if val, keep, err := variable.GetSessionStatesSystemVar(s.sessionVars, sv.Name); err == nil && keep { + sessionStates.SystemVars[sv.Name] = val + } + } + return } // DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface. func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) { - return s.sessionVars.DecodeSessionStates(ctx, sessionStates) + if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil { + return err + } + + // Decode session variables. + for name, val := range sessionStates.SystemVars { + if err = variable.SetSessionSystemVar(s.sessionVars, name, val); err != nil { + return err + } + } + return err } diff --git a/sessionctx/sessionstates/session_states.go b/sessionctx/sessionstates/session_states.go index 43adb554f5758..312cf891ec80e 100644 --- a/sessionctx/sessionstates/session_states.go +++ b/sessionctx/sessionstates/session_states.go @@ -24,4 +24,5 @@ import ( type SessionStates struct { UserVars map[string]*types.Datum `json:"user-var-values,omitempty"` UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"` + SystemVars map[string]string `json:"sys-vars,omitempty"` } diff --git a/sessionctx/sessionstates/session_states_test.go b/sessionctx/sessionstates/session_states_test.go index 61413039f29f1..81e4cb6d5285a 100644 --- a/sessionctx/sessionstates/session_states_test.go +++ b/sessionctx/sessionstates/session_states_test.go @@ -16,11 +16,14 @@ package sessionstates_test import ( "fmt" + "strconv" "strings" "testing" "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util/sem" 
"github.com/stretchr/testify/require" ) @@ -80,12 +83,134 @@ func TestUserVars(t *testing.T) { } } +func TestSystemVars(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tests := []struct { + stmts []string + varName string + inSessionStates bool + checkStmt string + expectedValue string + }{ + { + // normal variable + inSessionStates: true, + varName: variable.TiDBMaxTiFlashThreads, + expectedValue: strconv.Itoa(variable.DefTiFlashMaxThreads), + }, + { + // hidden variable + inSessionStates: false, + varName: variable.TiDBTxnReadTS, + }, + { + // none-scoped variable + inSessionStates: false, + varName: variable.DataDir, + expectedValue: "/usr/local/mysql/data/", + }, + { + // instance-scoped variable + inSessionStates: false, + varName: variable.TiDBGeneralLog, + expectedValue: "0", + }, + { + // global-scoped variable + inSessionStates: false, + varName: variable.TiDBAutoAnalyzeStartTime, + expectedValue: variable.DefAutoAnalyzeStartTime, + }, + { + // sem invisible variable + inSessionStates: false, + varName: variable.TiDBAllowRemoveAutoInc, + }, + { + // noop variables + stmts: []string{"set sql_buffer_result=true"}, + inSessionStates: true, + varName: "sql_buffer_result", + expectedValue: "1", + }, + { + stmts: []string{"set transaction isolation level repeatable read"}, + inSessionStates: true, + varName: "tx_isolation_one_shot", + expectedValue: "REPEATABLE-READ", + }, + { + inSessionStates: false, + varName: variable.Timestamp, + }, + { + stmts: []string{"set timestamp=100"}, + inSessionStates: true, + varName: variable.Timestamp, + expectedValue: "100", + }, + { + stmts: []string{"set rand_seed1=10000000, rand_seed2=1000000"}, + inSessionStates: true, + varName: variable.RandSeed1, + checkStmt: "select rand()", + expectedValue: "0.028870999839968048", + }, + { + stmts: []string{"set rand_seed1=10000000, rand_seed2=1000000", "select rand()"}, + inSessionStates: true, + varName: variable.RandSeed1, + checkStmt: "select 
rand()", + expectedValue: "0.11641535266900002", + }, + } + + sem.Enable() + for _, tt := range tests { + tk1 := testkit.NewTestKit(t, store) + for _, stmt := range tt.stmts { + if strings.HasPrefix(stmt, "select") { + tk1.MustQuery(stmt) + } else { + tk1.MustExec(stmt) + } + } + tk2 := testkit.NewTestKit(t, store) + rows := tk1.MustQuery("show session_states").Rows() + state := rows[0][0].(string) + msg := fmt.Sprintf("var name: '%s', expected value: '%s'", tt.varName, tt.expectedValue) + require.Equal(t, tt.inSessionStates, strings.Contains(state, tt.varName), msg) + state = strconv.Quote(state) + setSQL := fmt.Sprintf("set session_states %s", state) + tk2.MustExec(setSQL) + if len(tt.expectedValue) > 0 { + checkStmt := tt.checkStmt + if len(checkStmt) == 0 { + checkStmt = fmt.Sprintf("select @@%s", tt.varName) + } + tk2.MustQuery(checkStmt).Check(testkit.Rows(tt.expectedValue)) + } + } + + { + // The session value should not change even if the global value changes. + tk1 := testkit.NewTestKit(t, store) + tk1.MustQuery("select @@autocommit").Check(testkit.Rows("1")) + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("set global autocommit=0") + tk3 := testkit.NewTestKit(t, store) + showSessionStatesAndSet(t, tk1, tk3) + tk3.MustQuery("select @@autocommit").Check(testkit.Rows("1")) + } +} + func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) { rows := tk1.MustQuery("show session_states").Rows() require.Len(t, rows, 1) state := rows[0][0].(string) - state = strings.ReplaceAll(state, "\\", "\\\\") - state = strings.ReplaceAll(state, "'", "\\'") - setSQL := fmt.Sprintf("set session_states '%s'", state) + state = strconv.Quote(state) + setSQL := fmt.Sprintf("set session_states %s", state) tk2.MustExec(setSQL) } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 816757ea46539..ce3cdc66bfc7b 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -77,6 +77,9 @@ var defaultSysVars = []*SysVar{ } 
timestamp := s.StmtCtx.GetOrStoreStmtCache(stmtctx.StmtNowTsCacheKey, time.Now()).(time.Time) return types.ToString(float64(timestamp.UnixNano()) / float64(time.Second)) + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + timestamp, ok := s.systems[Timestamp] + return timestamp, ok && timestamp != DefTimestamp, nil }}, {Scope: ScopeSession, Name: WarningCount, Value: "0", ReadOnly: true, skipInit: true, GetSession: func(s *SessionVars) (string, error) { return strconv.Itoa(s.SysWarningCount), nil @@ -86,9 +89,13 @@ var defaultSysVars = []*SysVar{ }}, {Scope: ScopeSession, Name: LastInsertID, Value: "", skipInit: true, Type: TypeInt, AllowEmpty: true, MinValue: 0, MaxValue: math.MaxInt64, GetSession: func(s *SessionVars) (string, error) { return strconv.FormatUint(s.StmtCtx.PrevLastInsertID, 10), nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + return "", false, nil }}, {Scope: ScopeSession, Name: Identity, Value: "", skipInit: true, Type: TypeInt, AllowEmpty: true, MinValue: 0, MaxValue: math.MaxInt64, GetSession: func(s *SessionVars) (string, error) { return strconv.FormatUint(s.StmtCtx.PrevLastInsertID, 10), nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + return "", false, nil }}, /* TiDB specific variables */ // TODO: TiDBTxnScope is hidden because local txn feature is not done. 
@@ -140,11 +147,11 @@ var defaultSysVars = []*SysVar{ } return nil }}, - {Scope: ScopeSession, Name: TiDBOptProjectionPushDown, Value: BoolToOnOff(config.GetGlobalConfig().Performance.ProjectionPushDown), skipInit: true, Type: TypeBool, SetSession: func(s *SessionVars, val string) error { + {Scope: ScopeSession, Name: TiDBOptProjectionPushDown, Value: BoolToOnOff(config.GetGlobalConfig().Performance.ProjectionPushDown), Type: TypeBool, SetSession: func(s *SessionVars, val string) error { s.AllowProjectionPushDown = TiDBOptOn(val) return nil }}, - {Scope: ScopeSession, Name: TiDBOptAggPushDown, Value: BoolToOnOff(DefOptAggPushDown), Type: TypeBool, skipInit: true, SetSession: func(s *SessionVars, val string) error { + {Scope: ScopeSession, Name: TiDBOptAggPushDown, Value: BoolToOnOff(DefOptAggPushDown), Type: TypeBool, SetSession: func(s *SessionVars, val string) error { s.AllowAggPushDown = TiDBOptOn(val) return nil }}, @@ -192,6 +199,11 @@ var defaultSysVars = []*SysVar{ s.txnIsolationLevelOneShot.state = oneShotSet s.txnIsolationLevelOneShot.value = val return nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + if s.txnIsolationLevelOneShot.state != oneShotDef { + return s.txnIsolationLevelOneShot.value, true, nil + } + return "", false, nil }}, {Scope: ScopeSession, Name: TiDBOptimizerSelectivityLevel, Value: strconv.Itoa(DefTiDBOptimizerSelectivityLevel), skipInit: true, Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error { s.OptimizerSelectivityLevel = tidbOptPositiveInt32(val, DefTiDBOptimizerSelectivityLevel) @@ -307,12 +319,16 @@ var defaultSysVars = []*SysVar{ return nil }, GetSession: func(s *SessionVars) (string, error) { return "0", nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + return strconv.FormatUint(uint64(s.Rng.GetSeed1()), 10), true, nil }}, {Scope: ScopeSession, Name: RandSeed2, Type: TypeInt, Value: "0", skipInit: true, MaxValue: math.MaxInt32, 
SetSession: func(s *SessionVars, val string) error { s.Rng.SetSeed2(uint32(tidbOptPositiveInt32(val, 0))) return nil }, GetSession: func(s *SessionVars) (string, error) { return "0", nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + return strconv.FormatUint(uint64(s.Rng.GetSeed2()), 10), true, nil }}, {Scope: ScopeSession, Name: TiDBReadConsistency, Value: string(ReadConsistencyStrict), Type: TypeStr, Hidden: true, Validation: func(_ *SessionVars, normalized string, _ string, _ ScopeFlag) (string, error) { @@ -800,6 +816,12 @@ var defaultSysVars = []*SysVar{ }, GetGlobal: func(s *SessionVars) (string, error) { return BoolToOnOff(EnableConcurrentDDL.Load()), nil }}, + {Scope: ScopeGlobal, Name: TiDBEnableNoopVariables, Value: BoolToOnOff(DefTiDBEnableNoopVariables), Type: TypeEnum, PossibleValues: []string{Off, On, Warn}, SetGlobal: func(s *SessionVars, val string) error { + EnableNoopVariables.Store(TiDBOptOn(val)) + return nil + }, GetGlobal: func(s *SessionVars) (string, error) { + return BoolToOnOff(EnableNoopVariables.Load()), nil + }}, /* The system variables below have GLOBAL and SESSION scope */ {Scope: ScopeGlobal | ScopeSession, Name: SQLSelectLimit, Value: "18446744073709551615", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64, SetSession: func(s *SessionVars, val string) error { diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 64c1916292c73..03eca96b0e20c 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -670,6 +670,9 @@ const ( // TiDBQueryLogMaxLen is used to set the max length of the query in the log. TiDBQueryLogMaxLen = "tidb_query_log_max_len" + // TiDBEnableNoopVariables is used to indicate if noops appear in SHOW [GLOBAL] VARIABLES + TiDBEnableNoopVariables = "tidb_enable_noop_variables" + // TiDBNonTransactionalIgnoreError is used to ignore error in non-transactional DMLs. 
// When set to false, a non-transactional DML returns when it meets the first error. // When set to true, a non-transactional DML finishes all batches even if errors are met in some batches. @@ -855,6 +858,7 @@ const ( DefTiDBWaitSplitRegionFinish = true DefWaitSplitRegionTimeout = 300 // 300s DefTiDBEnableNoopFuncs = Off + DefTiDBEnableNoopVariables = true DefTiDBAllowRemoveAutoInc = false DefTiDBUsePlanBaselines = true DefTiDBEvolvePlanBaselines = false @@ -985,6 +989,7 @@ var ( PreparedPlanCacheSize = atomic.NewUint64(DefTiDBPrepPlanCacheSize) PreparedPlanCacheMemoryGuardRatio = atomic.NewFloat64(DefTiDBPrepPlanCacheMemoryGuardRatio) EnableConcurrentDDL = atomic.NewBool(DefTiDBEnableConcurrentDDL) + EnableNoopVariables = atomic.NewBool(DefTiDBEnableNoopVariables) ) var ( diff --git a/sessionctx/variable/variable.go b/sessionctx/variable/variable.go index db747819dee42..8a882f1d6e4f2 100644 --- a/sessionctx/variable/variable.go +++ b/sessionctx/variable/variable.go @@ -132,6 +132,9 @@ type SysVar struct { GetSession func(*SessionVars) (string, error) // GetGlobal is a getter function for global scope. GetGlobal func(*SessionVars) (string, error) + // GetStateValue gets the value for session states, which is used for migrating sessions. + // We need a function to override GetSession sometimes, because GetSession may not return the real value. + GetStateValue func(*SessionVars) (string, bool, error) // skipInit defines if the sysvar should be loaded into the session on init. // This is only important to set for sysvars that include session scope, // since global scoped sysvars are not-applicable. 
diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 39ec20cbe2fb1..ab878d2bb3054 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -193,6 +193,29 @@ func GetSessionOrGlobalSystemVar(s *SessionVars, name string) (string, error) { return sv.GetGlobalFromHook(s) } +// GetSessionStatesSystemVar gets the session variable value for session states. +// It's only used for encoding session states when migrating a session. +// The returned boolean indicates whether to keep this value in the session states. +func GetSessionStatesSystemVar(s *SessionVars, name string) (string, bool, error) { + sv := GetSysVar(name) + if sv == nil { + return "", false, ErrUnknownSystemVar.GenWithStackByArgs(name) + } + // Call GetStateValue first if it exists. Otherwise, call GetSession. + if sv.GetStateValue != nil { + return sv.GetStateValue(s) + } + if sv.GetSession != nil { + val, err := sv.GetSessionFromHook(s) + return val, err == nil, err + } + // Only get the cached value. No need to check the global or default value. + if val, ok := s.systems[sv.Name]; ok { + return val, true, nil + } + return "", false, nil +} + // GetGlobalSystemVar gets a global system variable. 
func GetGlobalSystemVar(s *SessionVars, name string) (string, error) { sv := GetSysVar(name) diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index 59cfb4cfca81b..4641a8c2f1e0d 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -673,3 +673,22 @@ func TestStmtVars(t *testing.T) { err = SetStmtVar(vars, MaxExecutionTime, "100") require.NoError(t, err) } + +func TestSessionStatesSystemVar(t *testing.T) { + vars := NewSessionVars() + err := SetSessionSystemVar(vars, "autocommit", "1") + require.NoError(t, err) + val, keep, err := GetSessionStatesSystemVar(vars, "autocommit") + require.NoError(t, err) + require.Equal(t, "ON", val) + require.Equal(t, true, keep) + _, keep, err = GetSessionStatesSystemVar(vars, Timestamp) + require.NoError(t, err) + require.Equal(t, false, keep) + err = SetSessionSystemVar(vars, MaxAllowedPacket, "1024") + require.NoError(t, err) + val, keep, err = GetSessionStatesSystemVar(vars, MaxAllowedPacket) + require.NoError(t, err) + require.Equal(t, "1024", val) + require.Equal(t, true, keep) +} diff --git a/util/mathutil/rand.go b/util/mathutil/rand.go index 6c93588a91129..a58c88281d638 100644 --- a/util/mathutil/rand.go +++ b/util/mathutil/rand.go @@ -67,3 +67,17 @@ func (rng *MysqlRng) SetSeed2(seed uint32) { defer rng.mu.Unlock() rng.seed2 = seed } + +// GetSeed1 is an interface to get seed1. It's only used for getting session states. +func (rng *MysqlRng) GetSeed1() uint32 { + rng.mu.Lock() + defer rng.mu.Unlock() + return rng.seed1 +} + +// GetSeed2 is an interface to get seed2. It's only used for getting session states. 
+func (rng *MysqlRng) GetSeed2() uint32 { + rng.mu.Lock() + defer rng.mu.Unlock() + return rng.seed2 +} diff --git a/util/mathutil/rand_test.go b/util/mathutil/rand_test.go index d0164f4de201f..0cc026604431c 100644 --- a/util/mathutil/rand_test.go +++ b/util/mathutil/rand_test.go @@ -68,4 +68,6 @@ func TestRandWithSeed1AndSeed2(t *testing.T) { require.Equal(t, rng.Gen(), 0.028870999839968048) require.Equal(t, rng.Gen(), 0.11641535266900002) require.Equal(t, rng.Gen(), 0.49546379455874096) + require.Equal(t, rng.GetSeed1(), uint32(532000198)) + require.Equal(t, rng.GetSeed2(), uint32(689000330)) }