diff --git a/br/cmd/tidb-lightning/main.go b/br/cmd/tidb-lightning/main.go
index 84362433f222c..41c2e8bebc630 100644
--- a/br/cmd/tidb-lightning/main.go
+++ b/br/cmd/tidb-lightning/main.go
@@ -23,6 +23,7 @@ import (
 	"syscall"
 
 	"github.com/pingcap/tidb/br/pkg/lightning"
+	"github.com/pingcap/tidb/br/pkg/lightning/common"
 	"github.com/pingcap/tidb/br/pkg/lightning/config"
 	"github.com/pingcap/tidb/br/pkg/lightning/log"
 	"go.uber.org/zap"
@@ -89,12 +90,21 @@ func main() {
 		return app.RunOnce(context.Background(), cfg, nil)
 	}()
 
+	finished := true
+	if common.IsContextCanceledError(err) {
+		err = nil
+		finished = false
+	}
 	if err != nil {
 		logger.Error("tidb lightning encountered error stack info", zap.Error(err))
 		fmt.Fprintln(os.Stderr, "tidb lightning encountered error: ", err)
 	} else {
-		logger.Info("tidb lightning exit")
-		fmt.Fprintln(os.Stdout, "tidb lightning exit")
+		logger.Info("tidb lightning exit", zap.Bool("finished", finished))
+		exitMsg := "tidb lightning exit successfully"
+		if !finished {
+			exitMsg = "tidb lightning canceled"
+		}
+		fmt.Fprintln(os.Stdout, exitMsg)
 	}
 
 	// call Sync() with log to stdout may return error in some case, so just skip it
diff --git a/br/pkg/lightning/backend/kv/kv2sql.go b/br/pkg/lightning/backend/kv/kv2sql.go
index a3c188a81eea7..47b9aa5393b2d 100644
--- a/br/pkg/lightning/backend/kv/kv2sql.go
+++ b/br/pkg/lightning/backend/kv/kv2sql.go
@@ -17,7 +17,6 @@ package kv
 import (
 	"fmt"
 
-	"github.com/pingcap/tidb/br/pkg/lightning/metric"
 	"github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/parser/model"
 	"github.com/pingcap/tidb/table"
@@ -38,15 +37,10 @@ func (t *TableKVDecoder) Name() string {
 	return t.tableName
 }
 
-func (t *TableKVDecoder) DecodeHandleFromTable(key []byte) (kv.Handle, error) {
+func (t *TableKVDecoder) DecodeHandleFromRowKey(key []byte) (kv.Handle, error) {
 	return tablecodec.DecodeRowKey(key)
 }
 
-func (t *TableKVDecoder) EncodeHandleKey(tableID int64, h kv.Handle) kv.Key {
-	// do not ever ever use tbl.Meta().ID, we need to deal with partitioned tables!
- return tablecodec.EncodeRowKeyWithHandle(tableID, h) -} - func (t *TableKVDecoder) DecodeHandleFromIndex(indexInfo *model.IndexInfo, key []byte, value []byte) (kv.Handle, error) { cols := tables.BuildRowcodecColInfoForIndexColumns(indexInfo, t.tbl.Meta()) return tablecodec.DecodeIndexHandle(key, value, len(cols)) @@ -111,7 +105,6 @@ func (t *TableKVDecoder) IterRawIndexKeys(h kv.Handle, rawRow []byte, fn func([] } func NewTableKVDecoder(tbl table.Table, tableName string, options *SessionOptions) (*TableKVDecoder, error) { - metric.KvEncoderCounter.WithLabelValues("open").Inc() se := newSession(options) cols := tbl.Cols() // Set CommonAddRecordCtx to session to reuse the slices and BufStore in AddRecord diff --git a/br/pkg/lightning/backend/kv/sql2kv_test.go b/br/pkg/lightning/backend/kv/sql2kv_test.go index 1a4e48bfe1c00..b9569e8957b18 100644 --- a/br/pkg/lightning/backend/kv/sql2kv_test.go +++ b/br/pkg/lightning/backend/kv/sql2kv_test.go @@ -158,7 +158,7 @@ func (s *kvSuite) TestDecode(c *C) { Key: []byte{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x5f, 0x72, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, Val: []byte{0x8, 0x2, 0x8, 0x2}, } - h, err := decoder.DecodeHandleFromTable(p.Key) + h, err := decoder.DecodeHandleFromRowKey(p.Key) c.Assert(err, IsNil) c.Assert(p.Val, NotNil) rows, _, err := decoder.DecodeRawRowData(h, p.Val) @@ -215,7 +215,7 @@ func (s *kvSuite) TestDecodeIndex(c *C) { Timestamp: 1234567890, }) c.Assert(err, IsNil) - h1, err := decoder.DecodeHandleFromTable(data.pairs[0].Key) + h1, err := decoder.DecodeHandleFromRowKey(data.pairs[0].Key) c.Assert(err, IsNil) h2, err := decoder.DecodeHandleFromIndex(tbl.Indices()[0].Meta(), data.pairs[1].Key, data.pairs[1].Val) c.Assert(err, IsNil) diff --git a/br/pkg/lightning/backend/local/duplicate.go b/br/pkg/lightning/backend/local/duplicate.go index 25872e19b73b5..6b0c9f9d66978 100644 --- a/br/pkg/lightning/backend/local/duplicate.go +++ b/br/pkg/lightning/backend/local/duplicate.go @@ -20,13 +20,16 @@ import ( "io" "math" "sort" - "time" + "sync" "github.com/cockroachdb/pebble" + "github.com/docker/go-units" + "github.com/google/btree" "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" + pkgkv "github.com/pingcap/tidb/br/pkg/kv" "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/errormanager" @@ -40,42 +43,19 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/ranger" - tikvclient "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikv" "go.uber.org/atomic" "go.uber.org/zap" - "go.uber.org/zap/zapcore" "golang.org/x/sync/errgroup" - "google.golang.org/grpc" - "google.golang.org/grpc/backoff" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/keepalive" ) const ( - maxGetRequestKeyCount = 1024 + maxDupCollectAttemptTimes = 5 + defaultRecordConflictErrorBatch = 1024 ) -type DuplicateRequest struct { - tableID int64 - start tidbkv.Key - end tidbkv.Key - indexInfo *model.IndexInfo -} - -type DuplicateManager struct { - errorMgr *errormanager.ErrorManager - splitCli restore.SplitClient - tikvCli *tikvclient.KVStore - regionConcurrency int - connPool common.GRPCConns - tls *common.TLS - ts uint64 - keyAdapter KeyAdapter - 
remoteWorkerPool *utils.WorkerPool - opts *kv.SessionOptions -} - type pendingIndexHandles struct { // all 4 slices should have exactly the same length. // we use a struct-of-arrays instead of array-of-structs @@ -165,6 +145,87 @@ func (indexHandles *pendingIndexHandles) searchSortedRawHandle(rawHandle []byte) }) } +type pendingKeyRange tidbkv.KeyRange + +func (kr pendingKeyRange) Less(other btree.Item) bool { + return bytes.Compare(kr.EndKey, other.(pendingKeyRange).EndKey) < 0 +} + +type pendingKeyRanges struct { + mu sync.Mutex + tree *btree.BTree +} + +func newPendingKeyRanges(keyRange tidbkv.KeyRange) *pendingKeyRanges { + tree := btree.New(32) + tree.ReplaceOrInsert(pendingKeyRange(keyRange)) + return &pendingKeyRanges{tree: tree} +} + +func (p *pendingKeyRanges) list() []tidbkv.KeyRange { + p.mu.Lock() + defer p.mu.Unlock() + + var keyRanges []tidbkv.KeyRange + p.tree.Ascend(func(item btree.Item) bool { + keyRanges = append(keyRanges, tidbkv.KeyRange(item.(pendingKeyRange))) + return true + }) + return keyRanges +} + +func (p *pendingKeyRanges) empty() bool { + return p.tree.Len() == 0 +} + +func (p *pendingKeyRanges) finish(keyRange tidbkv.KeyRange) { + p.mu.Lock() + defer p.mu.Unlock() + + var ( + pendingAdd []btree.Item + pendingRemove []btree.Item + ) + startKey := keyRange.StartKey + endKey := keyRange.EndKey + p.tree.AscendGreaterOrEqual( + pendingKeyRange(tidbkv.KeyRange{EndKey: startKey}), + func(item btree.Item) bool { + kr := item.(pendingKeyRange) + if bytes.Compare(startKey, kr.EndKey) >= 0 { + return true + } + if bytes.Compare(endKey, kr.StartKey) <= 0 { + return false + } + pendingRemove = append(pendingRemove, kr) + if bytes.Compare(startKey, kr.StartKey) > 0 { + pendingAdd = append(pendingAdd, + pendingKeyRange(tidbkv.KeyRange{ + StartKey: kr.StartKey, + EndKey: startKey, + }), + ) + } + if bytes.Compare(endKey, kr.EndKey) < 0 { + pendingAdd = append(pendingAdd, + pendingKeyRange(tidbkv.KeyRange{ + StartKey: endKey, + EndKey: kr.EndKey, + }), + ) + } + return true + }, + ) + for _, item := range pendingRemove { + p.tree.Delete(item) + } + for _, item := range pendingAdd { + p.tree.ReplaceOrInsert(item) + } +} + // physicalTableIDs returns all physical table IDs associated with the tableInfo. // A partitioned table can have multiple physical table IDs. func physicalTableIDs(tableInfo *model.TableInfo) []int64 { @@ -180,645 +241,632 @@ func physicalTableIDs(tableInfo *model.TableInfo) []int64 { return []int64{tableInfo.ID} } -// NewDuplicateManager creates a new *DuplicateManager. -// -// This object provides methods to collect and decode duplicated KV pairs into row data. The results -// are stored into the errorMgr. -func NewDuplicateManager(local *local, ts uint64, opts *kv.SessionOptions) (*DuplicateManager, error) { - return &DuplicateManager{ - errorMgr: local.errorMgr, - tls: local.tls, - regionConcurrency: local.tcpConcurrency, - splitCli: local.splitCli, - tikvCli: local.tikvCli, - keyAdapter: dupDetectKeyAdapter{}, - ts: ts, - connPool: common.NewGRPCConns(), - // TODO: not sure what is the correct concurrency value. - remoteWorkerPool: utils.NewWorkerPool(uint(local.tcpConcurrency), "duplicates"), - opts: opts, - }, nil -} - -// CollectDuplicateRowsFromTiKV collects duplicated rows already imported into TiKV. -// -// Collection result are saved into the ErrorManager. 
-func (manager *DuplicateManager) CollectDuplicateRowsFromTiKV( - ctx context.Context, - tbl table.Table, - tableName string, -) (hasDupe bool, err error) { - logTask := log.With(zap.String("table", tableName)).Begin(zapcore.InfoLevel, "collect duplicate data from remote TiKV") - defer func() { - logTask.End(zapcore.InfoLevel, err) - }() - - reqs, err := buildDuplicateRequests(tbl.Meta()) - if err != nil { - return false, err +// tableHandleKeyRanges returns all key ranges associated with the tableInfo. +func tableHandleKeyRanges(tableInfo *model.TableInfo) ([]tidbkv.KeyRange, error) { + ranges := ranger.FullIntRange(false) + if tableInfo.IsCommonHandle { + ranges = ranger.FullRange() } + tableIDs := physicalTableIDs(tableInfo) + return distsql.TableHandleRangesToKVRanges(nil, tableIDs, tableInfo.IsCommonHandle, ranges, nil) +} - // TODO: reuse the *kv.SessionOptions from NewEncoder for picking the correct time zone. - decoder, err := kv.NewTableKVDecoder(tbl, tableName, manager.opts) - if err != nil { - return false, err - } - g, rpcctx := errgroup.WithContext(ctx) - atomicHasDupe := atomic.NewBool(false) - for _, r := range reqs { - req := r - manager.remoteWorkerPool.ApplyOnErrorGroup(g, func() error { - err := manager.sendRequestToTiKV(rpcctx, decoder, req, atomicHasDupe) - if err != nil { - log.L().Error("error occur when collect duplicate data from TiKV", zap.Error(err)) - } - return err - }) +// tableIndexKeyRanges returns all key ranges associated with the tableInfo and indexInfo. +func tableIndexKeyRanges(tableInfo *model.TableInfo, indexInfo *model.IndexInfo) ([]tidbkv.KeyRange, error) { + tableIDs := physicalTableIDs(tableInfo) + var keyRanges []tidbkv.KeyRange + for _, tid := range tableIDs { + partitionKeysRanges, err := distsql.IndexRangesToKVRanges(nil, tid, indexInfo.ID, ranger.FullRange(), nil) + if err != nil { + return nil, errors.Trace(err) + } + keyRanges = append(keyRanges, partitionKeysRanges...) } - err = errors.Trace(g.Wait()) - return atomicHasDupe.Load(), err + return keyRanges, nil } -func (manager *DuplicateManager) sendRequestToTiKV(ctx context.Context, - decoder *kv.TableKVDecoder, - req *DuplicateRequest, - hasDupe *atomic.Bool, -) error { - logger := log.With( - zap.String("table", decoder.Name()), - zap.Int64("tableID", req.tableID), - logutil.Key("startKey", req.start), - logutil.Key("endKey", req.end)) +// DupKVStream is a streaming interface for collecting duplicate key-value pairs. +type DupKVStream interface { + // Next returns the next key-value pair or any error it encountered. + // At the end of the stream, the error is io.EOF. + Next() (key, val []byte, err error) + // Close closes the stream. + Close() error +} - startKey := codec.EncodeBytes([]byte{}, req.start) - endKey := codec.EncodeBytes([]byte{}, req.end) +// LocalDupKVStream implements the interface of DupKVStream. +// It collects duplicate key-value pairs from a pebble.DB. +//goland:noinspection GoNameStartsWithPackageName +type LocalDupKVStream struct { + iter pkgkv.Iter +} - regions, err := restore.PaginateScanRegion(ctx, manager.splitCli, startKey, endKey, scanRegionLimit) - if err != nil { - return err +// NewLocalDupKVStream creates a new LocalDupKVStream with the given duplicate db and key range. 
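+// It iterates the duplicate entries recorded inside keyRange in ascending key order, decoding keys with the given keyAdapter.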
+func NewLocalDupKVStream(dupDB *pebble.DB, keyAdapter KeyAdapter, keyRange tidbkv.KeyRange) *LocalDupKVStream { + opts := &pebble.IterOptions{ + LowerBound: keyRange.StartKey, + UpperBound: keyRange.EndKey, } - tryTimes := 0 - indexHandles := makePendingIndexHandlesWithCapacity(0) - for len(regions) > 0 { - if tryTimes > maxRetryTimes { - return errors.Errorf("retry time exceed limit") - } - unfinishedRegions := make([]*restore.RegionInfo, 0) - waitingClients := make([]import_sstpb.ImportSST_DuplicateDetectClient, 0) - watingRegions := make([]*restore.RegionInfo, 0) - for idx, region := range regions { - if len(waitingClients) > manager.regionConcurrency { - r := regions[idx:] - unfinishedRegions = append(unfinishedRegions, r...) - break - } - _, start, _ := codec.DecodeBytes(region.Region.StartKey, []byte{}) - _, end, _ := codec.DecodeBytes(region.Region.EndKey, []byte{}) - if bytes.Compare(startKey, region.Region.StartKey) > 0 { - start = req.start - } - if region.Region.EndKey == nil || len(region.Region.EndKey) == 0 || bytes.Compare(endKey, region.Region.EndKey) < 0 { - end = req.end - } - - logger.Debug("[detect-dupe] get duplicate stream", - zap.Int("localStreamID", idx), - logutil.Region(region.Region), - logutil.Leader(region.Leader), - logutil.Key("regionStartKey", start), - logutil.Key("regionEndKey", end)) - cli, err := manager.getDuplicateStream(ctx, region, start, end) - if err != nil { - r, err := manager.splitCli.GetRegionByID(ctx, region.Region.GetId()) - if err != nil { - unfinishedRegions = append(unfinishedRegions, region) - } else { - unfinishedRegions = append(unfinishedRegions, r) - } - } else { - waitingClients = append(waitingClients, cli) - watingRegions = append(watingRegions, region) - } - } + iter := newDupDBIter(dupDB, keyAdapter, opts) + iter.First() + return &LocalDupKVStream{iter: iter} +} - if indexHandles.Len() > 0 { - handles := manager.getValues(ctx, decoder, indexHandles) - if handles.Len() > 0 { - indexHandles = handles - } else { - indexHandles.truncate() - } +func (s *LocalDupKVStream) Next() (key, val []byte, err error) { + if !s.iter.Valid() { + err = s.iter.Error() + if err == nil { + err = io.EOF } + return + } + key = append(key, s.iter.Key()...) + val = append(val, s.iter.Value()...) 
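+	// advance the iterator so the next call returns the following pair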
+ s.iter.Next() + return +} - for idx, cli := range waitingClients { - region := watingRegions[idx] - cliLogger := logger.With( - zap.Int("localStreamID", idx), - logutil.Region(region.Region), - logutil.Leader(region.Leader)) - for { - resp, reqErr := cli.Recv() - hasErr := false - if reqErr != nil { - if errors.Cause(reqErr) == io.EOF { - cliLogger.Debug("[detect-dupe] exhausted duplication stream") - break - } - hasErr = true - } - - if hasErr || resp.GetKeyError() != nil { - r, err := manager.splitCli.GetRegionByID(ctx, region.Region.GetId()) - if err != nil { - unfinishedRegions = append(unfinishedRegions, region) - } else { - unfinishedRegions = append(unfinishedRegions, r) - } - } - if hasErr { - cliLogger.Warn("[detect-dupe] meet error when recving duplicate detect response from TiKV, retry again", - zap.Error(reqErr)) - break - } - if resp.GetKeyError() != nil { - cliLogger.Warn("[detect-dupe] meet key error in duplicate detect response from TiKV, retry again ", - zap.String("KeyError", resp.GetKeyError().GetMessage())) - break - } +func (s *LocalDupKVStream) Close() error { + return s.iter.Close() +} - if resp.GetRegionError() != nil { - cliLogger.Warn("[detect-dupe] meet key error in duplicate detect response from TiKV, retry again ", - zap.String("RegionError", resp.GetRegionError().GetMessage())) - - r, err := restore.PaginateScanRegion(ctx, manager.splitCli, watingRegions[idx].Region.GetStartKey(), watingRegions[idx].Region.GetEndKey(), scanRegionLimit) - if err != nil { - unfinishedRegions = append(unfinishedRegions, watingRegions[idx]) - } else { - unfinishedRegions = append(unfinishedRegions, r...) - } - break - } +type regionError struct { + inner *errorpb.Error +} - if len(resp.Pairs) > 0 { - hasDupe.Store(true) - } +func (r regionError) Error() string { + return r.inner.String() +} - handles, err := manager.storeDuplicateData(ctx, resp, decoder, req) - if err != nil { - return err - } - if handles.Len() > 0 { - indexHandles.extend(&handles) - } - } - } +// RemoteDupKVStream implements the interface of DupKVStream. +// It collects duplicate key-value pairs from a TiKV region. +type RemoteDupKVStream struct { + cli import_sstpb.ImportSST_DuplicateDetectClient + kvs []*import_sstpb.KvPair + atEOF bool + cancel context.CancelFunc +} - // it means that all the regions sent to TiKV fail, so we must sleep for a while to avoid retrying too frequently. - if len(unfinishedRegions) == len(regions) { - tryTimes += 1 - time.Sleep(defaultRetryBackoffTime) - } - regions = unfinishedRegions +func getDupDetectClient( + ctx context.Context, + region *restore.RegionInfo, + keyRange tidbkv.KeyRange, + importClientFactory ImportClientFactory, +) (import_sstpb.ImportSST_DuplicateDetectClient, error) { + leader := region.Leader + if leader == nil { + leader = region.Region.GetPeers()[0] } - return nil + importClient, err := importClientFactory.Create(ctx, leader.GetStoreId()) + if err != nil { + return nil, errors.Trace(err) + } + reqCtx := &kvrpcpb.Context{ + RegionId: region.Region.GetId(), + RegionEpoch: region.Region.GetRegionEpoch(), + Peer: leader, + } + req := &import_sstpb.DuplicateDetectRequest{ + Context: reqCtx, + StartKey: keyRange.StartKey, + EndKey: keyRange.EndKey, + } + cli, err := importClient.DuplicateDetect(ctx, req) + if err != nil { + return nil, errors.Trace(err) + } + return cli, nil } -func (manager *DuplicateManager) storeDuplicateData( +// NewRemoteDupKVStream creates a new RemoteDupKVStream. 
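+// It streams duplicate key-value pairs inside keyRange from the given region through the ImportSST DuplicateDetect RPC.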
+func NewRemoteDupKVStream( ctx context.Context, - resp *import_sstpb.DuplicateDetectResponse, - decoder *kv.TableKVDecoder, - req *DuplicateRequest, -) (pendingIndexHandles, error) { - var err error - var dataConflictInfos []errormanager.DataConflictInfo - indexHandles := makePendingIndexHandlesWithCapacity(len(resp.Pairs)) - - loggerIndexName := "PRIMARY" - if req.indexInfo != nil { - loggerIndexName = req.indexInfo.Name.O + region *restore.RegionInfo, + keyRange tidbkv.KeyRange, + importClientFactory ImportClientFactory, +) (*RemoteDupKVStream, error) { + subCtx, cancel := context.WithCancel(ctx) + cli, err := getDupDetectClient(subCtx, region, keyRange, importClientFactory) + if err != nil { + cancel() + return nil, errors.Trace(err) } - superLogger := log.With( - zap.String("table", decoder.Name()), - zap.Int64("tableID", req.tableID), - zap.String("index", loggerIndexName)) - - for _, kv := range resp.Pairs { - logger := superLogger.With( - logutil.Key("key", kv.Key), logutil.Key("value", kv.Value), - zap.Uint64("commit-ts", kv.CommitTs)) + s := &RemoteDupKVStream{cli: cli, cancel: cancel} + // call tryRecv to see if there are some region errors. + if err := s.tryRecv(); err != nil && errors.Cause(err) != io.EOF { + cancel() + return nil, errors.Trace(err) + } + return s, nil +} - var h tidbkv.Handle - if req.indexInfo != nil { - h, err = decoder.DecodeHandleFromIndex(req.indexInfo, kv.Key, kv.Value) - } else { - h, err = decoder.DecodeHandleFromTable(kv.Key) - } - if err != nil { - logger.Error("decode handle error", log.ShortError(err)) - continue +func (s *RemoteDupKVStream) tryRecv() error { + resp, err := s.cli.Recv() + if err != nil { + if errors.Cause(err) == io.EOF { + s.atEOF = true + err = io.EOF } - logger.Debug("[detect-dupe] remote dupe response", - logutil.Redact(zap.Stringer("handle", h))) + return err + } + if resp.RegionError != nil { + return errors.Cause(regionError{inner: resp.RegionError}) + } + if resp.KeyError != nil { + return errors.Errorf("meet key error in duplicate detect response: %s", resp.KeyError.Message) + } + s.kvs = resp.Pairs + return nil +} - conflictInfo := errormanager.DataConflictInfo{ - RawKey: kv.Key, - RawValue: kv.Value, - KeyData: h.String(), +func (s *RemoteDupKVStream) Next() (key, val []byte, err error) { + for len(s.kvs) == 0 { + if s.atEOF { + return nil, nil, io.EOF } - - if req.indexInfo != nil { - indexHandles.append( - conflictInfo, - req.indexInfo.Name.O, - h, decoder.EncodeHandleKey(req.tableID, h)) - } else { - conflictInfo.Row = decoder.DecodeRawRowDataAsStr(h, kv.Value) - dataConflictInfos = append(dataConflictInfos, conflictInfo) + if err := s.tryRecv(); err != nil { + return nil, nil, errors.Trace(err) } } + key, val = s.kvs[0].Key, s.kvs[0].Value + s.kvs = s.kvs[1:] + return +} - err = manager.errorMgr.RecordDataConflictError(ctx, log.L(), decoder.Name(), dataConflictInfos) - if err != nil { - return indexHandles, err - } +func (s *RemoteDupKVStream) Close() error { + s.cancel() + return nil +} - if len(indexHandles.dataConflictInfos) == 0 { - return indexHandles, nil - } - return manager.getValues(ctx, decoder, indexHandles), nil +// DuplicateManager provides methods to collect and decode duplicated KV pairs into row data. The results +// are stored into the errorMgr. 
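+// hasDupe is set to true once any duplicate key-value pair has been received.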
+type DuplicateManager struct { + tbl table.Table + tableName string + splitCli restore.SplitClient + tikvCli *tikv.KVStore + errorMgr *errormanager.ErrorManager + decoder *kv.TableKVDecoder + logger log.Logger + concurrency int + hasDupe *atomic.Bool } -// CollectDuplicateRowsFromLocalIndex collects rows by read the index in db. -func (manager *DuplicateManager) CollectDuplicateRowsFromLocalIndex( - ctx context.Context, +// NewDuplicateManager creates a new DuplicateManager. +func NewDuplicateManager( tbl table.Table, tableName string, - db *pebble.DB, -) (bool, error) { - // TODO: reuse the *kv.SessionOptions from NewEncoder for picking the correct time zone. - decoder, err := kv.NewTableKVDecoder(tbl, tableName, manager.opts) + splitCli restore.SplitClient, + tikvCli *tikv.KVStore, + errMgr *errormanager.ErrorManager, + sessOpts *kv.SessionOptions, + concurrency int, + hasDupe *atomic.Bool, +) (*DuplicateManager, error) { + decoder, err := kv.NewTableKVDecoder(tbl, tableName, sessOpts) if err != nil { - return false, errors.Trace(err) + return nil, errors.Trace(err) } + logger := log.With(zap.String("tableName", tableName)) + return &DuplicateManager{ + tbl: tbl, + tableName: tableName, + splitCli: splitCli, + tikvCli: tikvCli, + errorMgr: errMgr, + decoder: decoder, + logger: logger, + concurrency: concurrency, + hasDupe: hasDupe, + }, nil +} - logger := log.With(zap.String("table", tableName)) - - allRanges := make([]tidbkv.KeyRange, 0) - tableIDs := physicalTableIDs(tbl.Meta()) - // Collect row handle duplicates. +// RecordDataConflictError records data conflicts to errorMgr. The key received from stream must be a row key. +func (m *DuplicateManager) RecordDataConflictError(ctx context.Context, stream DupKVStream) error { + defer stream.Close() var dataConflictInfos []errormanager.DataConflictInfo - hasDataConflict := false - { - ranges := ranger.FullIntRange(false) - if tbl.Meta().IsCommonHandle { - ranges = ranger.FullRange() + for { + key, val, err := stream.Next() + if errors.Cause(err) == io.EOF { + break } - keyRanges, err := distsql.TableHandleRangesToKVRanges(nil, tableIDs, tbl.Meta().IsCommonHandle, ranges, nil) if err != nil { - return false, errors.Trace(err) + return errors.Trace(err) } - allRanges = append(allRanges, keyRanges...) 
- for _, r := range keyRanges { - logger.Debug("[detect-dupe] collect local range", - logutil.Key("startKey", r.StartKey), - logutil.Key("endKey", r.EndKey)) - startKey := codec.EncodeBytes([]byte{}, r.StartKey) - endKey := codec.EncodeBytes([]byte{}, r.EndKey) - opts := &pebble.IterOptions{ - LowerBound: startKey, - UpperBound: endKey, - } + m.hasDupe.Store(true) - if err := func() error { - iter := db.NewIter(opts) - defer iter.Close() - - for iter.First(); iter.Valid(); iter.Next() { - hasDataConflict = true - rawKey, err := manager.keyAdapter.Decode(nil, iter.Key()) - if err != nil { - return err - } - rawValue := make([]byte, len(iter.Value())) - copy(rawValue, iter.Value()) - - h, err := decoder.DecodeHandleFromTable(rawKey) - if err != nil { - return err - } - logger.Debug("[detect-dupe] found local data conflict", - logutil.Key("key", rawKey), - logutil.Key("value", rawValue), - logutil.Redact(zap.Stringer("handle", h))) - - conflictInfo := errormanager.DataConflictInfo{ - RawKey: rawKey, - RawValue: rawValue, - KeyData: h.String(), - Row: decoder.DecodeRawRowDataAsStr(h, rawValue), - } - dataConflictInfos = append(dataConflictInfos, conflictInfo) - } - if err := iter.Error(); err != nil { - return err - } - if err := manager.errorMgr.RecordDataConflictError(ctx, log.L(), decoder.Name(), dataConflictInfos); err != nil { - return err - } - dataConflictInfos = dataConflictInfos[:0] - return nil - }(); err != nil { - return false, errors.Trace(err) - } - db.DeleteRange(startKey, endKey, &pebble.WriteOptions{Sync: false}) - } - } - handles := makePendingIndexHandlesWithCapacity(0) - for _, indexInfo := range tbl.Meta().Indices { - if indexInfo.State != model.StatePublic { - continue + h, err := m.decoder.DecodeHandleFromRowKey(key) + if err != nil { + return errors.Trace(err) } - ranges := ranger.FullRange() - var keysRanges []tidbkv.KeyRange - for _, id := range tableIDs { - partitionKeysRanges, err := distsql.IndexRangesToKVRanges(nil, id, indexInfo.ID, ranges, nil) - if err != nil { - return false, err - } - keysRanges = append(keysRanges, partitionKeysRanges...) + conflictInfo := errormanager.DataConflictInfo{ + RawKey: key, + RawValue: val, + KeyData: h.String(), + Row: m.decoder.DecodeRawRowDataAsStr(h, val), } - allRanges = append(allRanges, keysRanges...) 
- for _, r := range keysRanges { - tableID := tablecodec.DecodeTableID(r.StartKey) - startKey := codec.EncodeBytes([]byte{}, r.StartKey) - endKey := codec.EncodeBytes([]byte{}, r.EndKey) - opts := &pebble.IterOptions{ - LowerBound: startKey, - UpperBound: endKey, - } - indexLogger := logger.With( - zap.Int64("tableID", tableID), - zap.String("index", indexInfo.Name.O), - zap.Int64("indexID", indexInfo.ID), - logutil.Key("startKey", startKey), - logutil.Key("endKey", endKey)) - indexLogger.Info("[detect-dupe] collect index from db") - - if err := func() error { - iter := db.NewIter(opts) - defer iter.Close() - - for iter.First(); iter.Valid(); iter.Next() { - hasDataConflict = true - rawKey, err := manager.keyAdapter.Decode(nil, iter.Key()) - if err != nil { - indexLogger.Error( - "[detect-dupe] decode key error when query handle for duplicate index", - zap.Binary("key", iter.Key()), - ) - return err - } - rawValue := make([]byte, len(iter.Value())) - copy(rawValue, iter.Value()) - h, err := decoder.DecodeHandleFromIndex(indexInfo, rawKey, rawValue) - if err != nil { - indexLogger.Error("[detect-dupe] decode handle error from index for duplicatedb", - zap.Error(err), logutil.Key("rawKey", rawKey), - logutil.Key("value", rawValue)) - return err - } - indexLogger.Debug("[detect-dupe] found local index conflict, stashing", - logutil.Key("key", rawKey), - logutil.Key("value", rawValue), - logutil.Redact(zap.Stringer("handle", h))) - handles.append( - errormanager.DataConflictInfo{ - RawKey: rawKey, - RawValue: rawValue, - KeyData: h.String(), - }, - indexInfo.Name.O, - h, - decoder.EncodeHandleKey(tableID, h)) - if handles.Len() > maxGetRequestKeyCount { - handles = manager.getValues(ctx, decoder, handles) - } - } - if handles.Len() > 0 { - handles = manager.getValues(ctx, decoder, handles) - } - if handles.Len() == 0 { - db.DeleteRange(startKey, endKey, &pebble.WriteOptions{Sync: false}) - } - return nil - }(); err != nil { - return false, errors.Trace(err) + dataConflictInfos = append(dataConflictInfos, conflictInfo) + if len(dataConflictInfos) >= defaultRecordConflictErrorBatch { + if err := m.errorMgr.RecordDataConflictError(ctx, m.logger, m.tableName, dataConflictInfos); err != nil { + return errors.Trace(err) } + dataConflictInfos = dataConflictInfos[:0] } } - - for i := 0; i < maxRetryTimes && handles.Len() > 0; i++ { - handles = manager.getValues(ctx, decoder, handles) - } - if handles.Len() > 0 { - return false, errors.Errorf("retry getValues time exceed limit") - } - for _, r := range allRanges { - startKey := codec.EncodeBytes([]byte{}, r.StartKey) - endKey := codec.EncodeBytes([]byte{}, r.EndKey) - db.DeleteRange(startKey, endKey, &pebble.WriteOptions{Sync: false}) + if len(dataConflictInfos) > 0 { + if err := m.errorMgr.RecordDataConflictError(ctx, m.logger, m.tableName, dataConflictInfos); err != nil { + return errors.Trace(err) + } } - return hasDataConflict, nil + return nil } -func (manager *DuplicateManager) getValues( - ctx context.Context, - decoder *kv.TableKVDecoder, - handles pendingIndexHandles, -) pendingIndexHandles { - var finalErr error - logger := log.With( - zap.String("table", decoder.Name()), - zap.Int("handlesCount", handles.Len()), - ).Begin(zap.DebugLevel, "[detect-dupe] collect values from TiKV") - defer func() { - logger.End(zap.ErrorLevel, finalErr) - }() - - // TODO: paginate the handles. 
- snapshot := manager.tikvCli.GetSnapshot(math.MaxUint64) +func (m *DuplicateManager) saveIndexHandles(ctx context.Context, handles pendingIndexHandles) error { + snapshot := m.tikvCli.GetSnapshot(math.MaxUint64) batchGetMap, err := snapshot.BatchGet(ctx, handles.rawHandles) if err != nil { - finalErr = err - return handles + return errors.Trace(err) } - retryHandles := makePendingIndexHandlesWithCapacity(0) - batch := makePendingIndexHandlesWithCapacity(handles.Len()) - rawRows := make([][]byte, 0, handles.Len()) + rawRows := make([][]byte, handles.Len()) for i, rawHandle := range handles.rawHandles { - rawValue, ok := batchGetMap[string(rawHandle)] + rawValue, ok := batchGetMap[string(hack.String(rawHandle))] if ok { - logger.Debug("[detect-dupe] retrieved value from TiKV", - logutil.Key("rawHandle", rawHandle), - logutil.Key("row", rawValue)) - rawRows = append(rawRows, rawValue) - handles.dataConflictInfos[i].Row = decoder.DecodeRawRowDataAsStr(handles.handles[i], rawValue) - batch.appendAt(&handles, i) + rawRows[i] = rawValue + handles.dataConflictInfos[i].Row = m.decoder.DecodeRawRowDataAsStr(handles.handles[i], rawValue) } else { - logger.Warn("[detect-dupe] missing value from TiKV, will retry", + m.logger.Warn("[detect-dupe] can not found row data corresponding to the handle", logutil.Key("rawHandle", rawHandle)) - retryHandles.appendAt(&handles, i) } } - finalErr = manager.errorMgr.RecordIndexConflictError( - ctx, log.L(), - decoder.Name(), - batch.indexNames, - batch.dataConflictInfos, - batch.rawHandles, - rawRows) - if finalErr != nil { - return handles - } - - return retryHandles + err = m.errorMgr.RecordIndexConflictError(ctx, m.logger, m.tableName, + handles.indexNames, handles.dataConflictInfos, handles.rawHandles, rawRows) + return errors.Trace(err) } -func (manager *DuplicateManager) getDuplicateStream(ctx context.Context, - region *restore.RegionInfo, - start []byte, end []byte) (import_sstpb.ImportSST_DuplicateDetectClient, error) { - leader := region.Leader - if leader == nil { - leader = region.Region.GetPeers()[0] - } +// RecordIndexConflictError records index conflicts to errorMgr. The key received from stream must be an index key. 
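+// Row values for the conflicting handles are fetched from TiKV in batches via saveIndexHandles.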
+func (m *DuplicateManager) RecordIndexConflictError(ctx context.Context, stream DupKVStream, tableID int64, indexInfo *model.IndexInfo) error { + defer stream.Close() + indexHandles := makePendingIndexHandlesWithCapacity(0) + for { + key, val, err := stream.Next() + if errors.Cause(err) == io.EOF { + break + } + if err != nil { + return errors.Trace(err) + } + m.hasDupe.Store(true) - cli, err := manager.getImportClient(ctx, leader) - if err != nil { - return nil, err - } + h, err := m.decoder.DecodeHandleFromIndex(indexInfo, key, val) + if err != nil { + return errors.Trace(err) + } + conflictInfo := errormanager.DataConflictInfo{ + RawKey: key, + RawValue: val, + KeyData: h.String(), + } + indexHandles.append(conflictInfo, indexInfo.Name.O, + h, tablecodec.EncodeRowKeyWithHandle(tableID, h)) - reqCtx := &kvrpcpb.Context{ - RegionId: region.Region.GetId(), - RegionEpoch: region.Region.GetRegionEpoch(), - Peer: leader, + if indexHandles.Len() >= defaultRecordConflictErrorBatch { + if err := m.saveIndexHandles(ctx, indexHandles); err != nil { + return errors.Trace(err) + } + indexHandles.truncate() + } } - req := &import_sstpb.DuplicateDetectRequest{ - Context: reqCtx, - StartKey: start, - EndKey: end, - KeyOnly: false, + if indexHandles.Len() > 0 { + if err := m.saveIndexHandles(ctx, indexHandles); err != nil { + return errors.Trace(err) + } } - stream, err := cli.DuplicateDetect(ctx, req) - return stream, err + return nil } -func (manager *DuplicateManager) getImportClient(ctx context.Context, peer *metapb.Peer) (import_sstpb.ImportSSTClient, error) { - conn, err := manager.connPool.GetGrpcConn(ctx, peer.GetStoreId(), 1, func(ctx context.Context) (*grpc.ClientConn, error) { - return manager.makeConn(ctx, peer.GetStoreId()) - }) - if err != nil { - return nil, err - } - return import_sstpb.NewImportSSTClient(conn), nil +type dupTask struct { + tidbkv.KeyRange + tableID int64 + indexInfo *model.IndexInfo } -func (manager *DuplicateManager) makeConn(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { - store, err := manager.splitCli.GetStore(ctx, storeID) +func (m *DuplicateManager) buildDupTasks() ([]dupTask, error) { + var tasks []dupTask + keyRanges, err := tableHandleKeyRanges(m.tbl.Meta()) if err != nil { return nil, errors.Trace(err) } - opt := grpc.WithInsecure() - if manager.tls.TLSConfig() != nil { - opt = grpc.WithTransportCredentials(credentials.NewTLS(manager.tls.TLSConfig())) - } - ctx, cancel := context.WithTimeout(ctx, dialTimeout) - - bfConf := backoff.DefaultConfig - bfConf.MaxDelay = gRPCBackOffMaxDelay - // we should use peer address for tiflash. 
for tikv, peer address is empty - addr := store.GetPeerAddress() - if addr == "" { - addr = store.GetAddress() - } - conn, err := grpc.DialContext( - ctx, - addr, - opt, - grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: gRPCKeepAliveTime, - Timeout: gRPCKeepAliveTimeout, - PermitWithoutStream: true, - }), - ) - cancel() + for _, kr := range keyRanges { + tableID := tablecodec.DecodeTableID(kr.StartKey) + tasks = append(tasks, dupTask{ + KeyRange: kr, + tableID: tableID, + }) + } + for _, indexInfo := range m.tbl.Meta().Indices { + if indexInfo.State != model.StatePublic { + continue + } + keyRanges, err = tableIndexKeyRanges(m.tbl.Meta(), indexInfo) + if err != nil { + return nil, errors.Trace(err) + } + for _, kr := range keyRanges { + tableID := tablecodec.DecodeTableID(kr.StartKey) + tasks = append(tasks, dupTask{ + KeyRange: kr, + tableID: tableID, + indexInfo: indexInfo, + }) + } + } + return tasks, nil +} + +func (m *DuplicateManager) splitLocalDupTaskByKeys( + task dupTask, + dupDB *pebble.DB, + keyAdapter KeyAdapter, + sizeLimit int64, + keysLimit int64, +) ([]dupTask, error) { + sizeProps, err := getSizeProperties(m.logger, dupDB, keyAdapter) if err != nil { return nil, errors.Trace(err) } - return conn, nil + ranges := splitRangeBySizeProps(Range{start: task.StartKey, end: task.EndKey}, sizeProps, sizeLimit, keysLimit) + var newDupTasks []dupTask + for _, r := range ranges { + newDupTasks = append(newDupTasks, dupTask{ + KeyRange: tidbkv.KeyRange{ + StartKey: r.start, + EndKey: r.end, + }, + tableID: task.tableID, + indexInfo: task.indexInfo, + }) + } + return newDupTasks, nil } -func buildDuplicateRequests(tableInfo *model.TableInfo) ([]*DuplicateRequest, error) { - var reqs []*DuplicateRequest - for _, id := range physicalTableIDs(tableInfo) { - tableReqs, err := buildTableRequests(id, tableInfo.IsCommonHandle) +func (m *DuplicateManager) buildLocalDupTasks(dupDB *pebble.DB, keyAdapter KeyAdapter) ([]dupTask, error) { + tasks, err := m.buildDupTasks() + if err != nil { + return nil, errors.Trace(err) + } + var newTasks []dupTask + for _, task := range tasks { + // FIXME: Do not hardcode sizeLimit and keysLimit. + subTasks, err := m.splitLocalDupTaskByKeys(task, dupDB, keyAdapter, 32*units.MiB, 1*units.MiB) if err != nil { return nil, errors.Trace(err) } - reqs = append(reqs, tableReqs...) - for _, indexInfo := range tableInfo.Indices { - if indexInfo.State != model.StatePublic { - continue + newTasks = append(newTasks, subTasks...) + } + return newTasks, nil +} + +// CollectDuplicateRowsFromDupDB collects duplicates from the duplicate DB and records all duplicate row info into errorMgr. 
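+// Each collected key range is deleted from the duplicate DB after it has been recorded.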
+func (m *DuplicateManager) CollectDuplicateRowsFromDupDB(ctx context.Context, dupDB *pebble.DB, keyAdapter KeyAdapter) error { + tasks, err := m.buildLocalDupTasks(dupDB, keyAdapter) + if err != nil { + return errors.Trace(err) + } + logger := m.logger + logger.Info("[detect-dupe] collect duplicate rows from local duplicate db", zap.Int("tasks", len(tasks))) + + pool := utils.NewWorkerPool(uint(m.concurrency), "collect duplicate rows from duplicate db") + g, gCtx := errgroup.WithContext(ctx) + for _, task := range tasks { + task := task + pool.ApplyOnErrorGroup(g, func() error { + if err := common.Retry("collect local duplicate rows", logger, func() error { + stream := NewLocalDupKVStream(dupDB, keyAdapter, task.KeyRange) + var err error + if task.indexInfo == nil { + err = m.RecordDataConflictError(gCtx, stream) + } else { + err = m.RecordIndexConflictError(gCtx, stream, task.tableID, task.indexInfo) + } + return errors.Trace(err) + }); err != nil { + return errors.Trace(err) + } + + // Delete the key range in duplicate DB since we have the duplicates have been collected. + rawStartKey := keyAdapter.Encode(nil, task.StartKey, math.MinInt64) + rawEndKey := keyAdapter.Encode(nil, task.EndKey, math.MinInt64) + err = dupDB.DeleteRange(rawStartKey, rawEndKey, nil) + return errors.Trace(err) + }) + } + return errors.Trace(g.Wait()) +} + +func (m *DuplicateManager) splitKeyRangeByRegions( + ctx context.Context, keyRange tidbkv.KeyRange, +) ([]*restore.RegionInfo, []tidbkv.KeyRange, error) { + rawStartKey := codec.EncodeBytes(nil, keyRange.StartKey) + rawEndKey := codec.EncodeBytes(nil, keyRange.EndKey) + allRegions, err := restore.PaginateScanRegion(ctx, m.splitCli, rawStartKey, rawEndKey, 1024) + if err != nil { + return nil, nil, errors.Trace(err) + } + regions := make([]*restore.RegionInfo, 0, len(allRegions)) + keyRanges := make([]tidbkv.KeyRange, 0, len(allRegions)) + for _, region := range allRegions { + startKey := keyRange.StartKey + endKey := keyRange.EndKey + if len(region.Region.StartKey) > 0 { + _, regionStartKey, err := codec.DecodeBytes(region.Region.StartKey, nil) + if err != nil { + return nil, nil, errors.Trace(err) } - indexReqs, err := buildIndexRequests(id, indexInfo) + if bytes.Compare(startKey, regionStartKey) < 0 { + startKey = regionStartKey + } + } + if len(region.Region.EndKey) > 0 { + _, regionEndKey, err := codec.DecodeBytes(region.Region.EndKey, nil) if err != nil { - return nil, errors.Trace(err) + return nil, nil, errors.Trace(err) } - reqs = append(reqs, indexReqs...) 
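+			// shrink the range's end key so it stays within this region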
+ if bytes.Compare(endKey, regionEndKey) > 0 { + endKey = regionEndKey + } + } + if bytes.Compare(startKey, endKey) < 0 { + regions = append(regions, region) + keyRanges = append(keyRanges, tidbkv.KeyRange{ + StartKey: startKey, + EndKey: endKey, + }) } } - return reqs, nil + return regions, keyRanges, nil } -func buildTableRequests(tableID int64, isCommonHandle bool) ([]*DuplicateRequest, error) { - ranges := ranger.FullIntRange(false) - if isCommonHandle { - ranges = ranger.FullRange() +func (m *DuplicateManager) processRemoteDupTaskOnce( + ctx context.Context, + task dupTask, + logger log.Logger, + importClientFactory ImportClientFactory, + regionPool *utils.WorkerPool, + remainKeyRanges *pendingKeyRanges, +) (madeProgress bool, err error) { + var ( + regions []*restore.RegionInfo + keyRanges []tidbkv.KeyRange + ) + for _, kr := range remainKeyRanges.list() { + subRegions, subKeyRanges, err := m.splitKeyRangeByRegions(ctx, kr) + if err != nil { + return false, errors.Trace(err) + } + regions = append(regions, subRegions...) + keyRanges = append(keyRanges, subKeyRanges...) } - keysRanges, err := distsql.TableHandleRangesToKVRanges(nil, []int64{tableID}, isCommonHandle, ranges, nil) - if err != nil { - return nil, errors.Trace(err) + + var metErr common.OnceError + wg := &sync.WaitGroup{} + atomicMadeProgress := atomic.NewBool(false) + for i := 0; i < len(regions); i++ { + if ctx.Err() != nil { + metErr.Set(ctx.Err()) + break + } + region := regions[i] + kr := keyRanges[i] + wg.Add(1) + regionPool.Apply(func() { + defer wg.Done() + + logger := logger.With( + zap.Uint64("regionID", region.Region.Id), + logutil.Key("dupDetectStartKey", kr.StartKey), + logutil.Key("dupDetectEndKey", kr.EndKey), + ) + err := func() error { + stream, err := NewRemoteDupKVStream(ctx, region, kr, importClientFactory) + if err != nil { + return errors.Annotatef(err, "failed to create remote duplicate kv stream") + } + if task.indexInfo == nil { + err = m.RecordDataConflictError(ctx, stream) + } else { + err = m.RecordIndexConflictError(ctx, stream, task.tableID, task.indexInfo) + } + if err != nil { + return errors.Annotatef(err, "failed to record conflict errors") + } + return nil + }() + if err != nil { + if regionErr, ok := errors.Cause(err).(regionError); ok { + logger.Debug("[detect-dupe] collect duplicate rows from region failed due to region error", zap.Error(regionErr)) + } else { + logger.Warn("[detect-dupe] collect duplicate rows from region failed", log.ShortError(err)) + } + metErr.Set(err) + } else { + logger.Debug("[detect-dupe] collect duplicate rows from region completed") + remainKeyRanges.finish(kr) + atomicMadeProgress.Store(true) + } + }) } - reqs := make([]*DuplicateRequest, 0) - for _, r := range keysRanges { - req := &DuplicateRequest{ - start: r.StartKey, - end: r.EndKey, - tableID: tableID, - indexInfo: nil, + wg.Wait() + return atomicMadeProgress.Load(), errors.Trace(metErr.Get()) +} + +// processRemoteDupTask processes a remoteDupTask. A task contains a key range. +// A key range is associated with multiple regions. processRemoteDupTask tries +// to collect duplicates from each region. 
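+// Key ranges whose regions fail are kept pending and retried; a retry attempt is only consumed when a whole pass makes no progress.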
+func (m *DuplicateManager) processRemoteDupTask( + ctx context.Context, + task dupTask, + logger log.Logger, + importClientFactory ImportClientFactory, + regionPool *utils.WorkerPool, +) error { + remainAttempts := maxDupCollectAttemptTimes + remainKeyRanges := newPendingKeyRanges(task.KeyRange) + for { + madeProgress, err := m.processRemoteDupTaskOnce(ctx, task, logger, importClientFactory, regionPool, remainKeyRanges) + if err == nil { + if !remainKeyRanges.empty() { + remainKeyRanges.list() + logger.Panic("[detect-dupe] there are still some key ranges that haven't been processed, which is unexpected", + zap.Any("remainKeyRanges", remainKeyRanges.list())) + } + return nil } - reqs = append(reqs, req) + if log.IsContextCanceledError(err) { + return errors.Trace(err) + } + if !madeProgress { + remainAttempts-- + if remainAttempts <= 0 { + logger.Error("[detect-dupe] all attempts to process the remote dupTask have failed", log.ShortError(err)) + return errors.Trace(err) + } + } + logger.Warn("[detect-dupe] process remote dupTask encounters error, retrying", + log.ShortError(err), zap.Int("remainAttempts", remainAttempts)) } - return reqs, nil } -func buildIndexRequests(tableID int64, indexInfo *model.IndexInfo) ([]*DuplicateRequest, error) { - ranges := ranger.FullRange() - keysRanges, err := distsql.IndexRangesToKVRanges(nil, tableID, indexInfo.ID, ranges, nil) +// CollectDuplicateRowsFromTiKV collects duplicates from the remote TiKV and records all duplicate row info into errorMgr. +func (m *DuplicateManager) CollectDuplicateRowsFromTiKV(ctx context.Context, importClientFactory ImportClientFactory) error { + tasks, err := m.buildDupTasks() if err != nil { - return nil, errors.Trace(err) - } - reqs := make([]*DuplicateRequest, 0) - for _, r := range keysRanges { - req := &DuplicateRequest{ - start: r.StartKey, - end: r.EndKey, - tableID: tableID, - indexInfo: indexInfo, - } - reqs = append(reqs, req) + return errors.Trace(err) + } + logger := m.logger + logger.Info("[detect-dupe] collect duplicate rows from tikv", zap.Int("tasks", len(tasks))) + + taskPool := utils.NewWorkerPool(uint(m.concurrency), "collect duplicate rows from tikv") + regionPool := utils.NewWorkerPool(uint(m.concurrency), "collect duplicate rows from tikv by region") + g, gCtx := errgroup.WithContext(ctx) + for _, task := range tasks { + task := task + taskPool.ApplyOnErrorGroup(g, func() error { + taskLogger := logger.With( + logutil.Key("startKey", task.StartKey), + logutil.Key("endKey", task.EndKey), + zap.Int64("tableID", task.tableID), + ) + if task.indexInfo != nil { + taskLogger = taskLogger.With( + zap.String("indexName", task.indexInfo.Name.O), + zap.Int64("indexID", task.indexInfo.ID), + ) + } + err := m.processRemoteDupTask(gCtx, task, taskLogger, importClientFactory, regionPool) + return errors.Trace(err) + }) } - return reqs, nil + return errors.Trace(g.Wait()) } diff --git a/br/pkg/lightning/backend/local/engine.go b/br/pkg/lightning/backend/local/engine.go index 18be072afc054..e644b1abc938a 100644 --- a/br/pkg/lightning/backend/local/engine.go +++ b/br/pkg/lightning/backend/local/engine.go @@ -73,10 +73,6 @@ type engineMeta struct { Length atomic.Int64 `json:"length"` // TotalSize is the total pre-compressed KV byte size stored by engine. TotalSize atomic.Int64 `json:"total_size"` - // Duplicates is the number of duplicates kv pairs detected when importing. Note that the value is - // probably larger than real value, because we may import same range more than once. 
For accurate - // information, you should iterate the duplicate db after import is finished. - Duplicates atomic.Int64 `json:"duplicates"` } type syncedRanges struct { @@ -257,28 +253,6 @@ var _ btree.Item = &rangeProperty{} type rangeProperties []rangeProperty -func decodeRangeProperties(data []byte) (rangeProperties, error) { - r := make(rangeProperties, 0, 16) - for len(data) > 0 { - if len(data) < 4 { - return nil, io.ErrUnexpectedEOF - } - keyLen := int(binary.BigEndian.Uint32(data[:4])) - data = data[4:] - if len(data) < keyLen+8*2 { - return nil, io.ErrUnexpectedEOF - } - key := data[:keyLen] - data = data[keyLen:] - size := binary.BigEndian.Uint64(data[:8]) - keys := binary.BigEndian.Uint64(data[8:]) - data = data[16:] - r = append(r, rangeProperty{Key: key, rangeOffsets: rangeOffsets{Size: size, Keys: keys}}) - } - - return r, nil -} - func (r rangeProperties) Encode() []byte { b := make([]byte, 0, 1024) idx := 0 @@ -340,6 +314,9 @@ func (c *RangePropertiesCollector) insertNewPoint(key []byte) { // Add implements `pebble.TablePropertyCollector`. // Add implements `TablePropertyCollector.Add`. func (c *RangePropertiesCollector) Add(key pebble.InternalKey, value []byte) error { + if key.Kind() != pebble.InternalKeyKindSet || bytes.Equal(key.UserKey, engineMetaKey) { + return nil + } c.currentOffsets.Size += uint64(len(value)) + uint64(len(key.UserKey)) c.currentOffsets.Keys++ if len(c.lastKey) == 0 || c.sizeInLastRange() >= c.propSizeIdxDistance || @@ -390,7 +367,7 @@ func (s *sizeProperties) addAll(props rangeProperties) { prevRange = r.rangeOffsets } if len(props) > 0 { - s.totalSize = props[len(props)-1].Size + s.totalSize += props[len(props)-1].Size } } @@ -402,10 +379,38 @@ func (s *sizeProperties) iter(f func(p *rangeProperty) bool) { }) } -func (e *Engine) getSizeProperties() (*sizeProperties, error) { - sstables, err := e.db.SSTables(pebble.WithProperties()) +func decodeRangeProperties(data []byte, keyAdapter KeyAdapter) (rangeProperties, error) { + r := make(rangeProperties, 0, 16) + for len(data) > 0 { + if len(data) < 4 { + return nil, io.ErrUnexpectedEOF + } + keyLen := int(binary.BigEndian.Uint32(data[:4])) + data = data[4:] + if len(data) < keyLen+8*2 { + return nil, io.ErrUnexpectedEOF + } + key := data[:keyLen] + data = data[keyLen:] + size := binary.BigEndian.Uint64(data[:8]) + keys := binary.BigEndian.Uint64(data[8:]) + data = data[16:] + if !bytes.Equal(key, engineMetaKey) { + userKey, err := keyAdapter.Decode(nil, key) + if err != nil { + return nil, errors.Annotate(err, "failed to decode key with keyAdapter") + } + r = append(r, rangeProperty{Key: userKey, rangeOffsets: rangeOffsets{Size: size, Keys: keys}}) + } + } + + return r, nil +} + +func getSizeProperties(logger log.Logger, db *pebble.DB, keyAdapter KeyAdapter) (*sizeProperties, error) { + sstables, err := db.SSTables(pebble.WithProperties()) if err != nil { - log.L().Warn("get table properties failed", zap.Stringer("engine", e.UUID), log.ShortError(err)) + logger.Warn("get sst table properties failed", log.ShortError(err)) return nil, errors.Trace(err) } @@ -414,31 +419,12 @@ func (e *Engine) getSizeProperties() (*sizeProperties, error) { for _, info := range level { if prop, ok := info.Properties.UserProperties[propRangeIndex]; ok { data := hack.Slice(prop) - rangeProps, err := decodeRangeProperties(data) + rangeProps, err := decodeRangeProperties(data, keyAdapter) if err != nil { - log.L().Warn("decodeRangeProperties failed", zap.Stringer("engine", e.UUID), + logger.Warn("decodeRangeProperties failed", 
zap.Stringer("fileNum", info.FileNum), log.ShortError(err)) return nil, errors.Trace(err) } - if e.duplicateDetection { - newRangeProps := make(rangeProperties, 0, len(rangeProps)) - for _, p := range rangeProps { - if !bytes.Equal(p.Key, engineMetaKey) { - p.Key, err = e.keyAdapter.Decode(nil, p.Key) - if err != nil { - log.L().Warn( - "decodeRangeProperties failed because the props key is invalid", - zap.Stringer("engine", e.UUID), - zap.Stringer("fileNum", info.FileNum), - zap.Binary("key", p.Key), - ) - return nil, errors.Trace(err) - } - newRangeProps = append(newRangeProps, p) - } - } - rangeProps = newRangeProps - } sizeProps.addAll(rangeProps) } } diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go index babb5a328000f..b4a5826e5651f 100644 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -57,6 +57,7 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/codec" + tikverror "github.com/tikv/client-go/v2/error" "github.com/tikv/client-go/v2/oracle" tikvclient "github.com/tikv/client-go/v2/tikv" pd "github.com/tikv/pd/client" @@ -64,6 +65,7 @@ import ( "go.uber.org/multierr" "go.uber.org/zap" "golang.org/x/sync/errgroup" + "golang.org/x/time/rate" "google.golang.org/grpc" "google.golang.org/grpc/backoff" "google.golang.org/grpc/codes" @@ -111,8 +113,82 @@ var ( errorEngineClosed = errors.New("engine is closed") ) -// getImportClientFn is a variable alias for getImportClient used for unit test. -var getImportClientFn = getImportClient +// ImportClientFactory is factory to create new import client for specific store. +type ImportClientFactory interface { + Create(ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) + Close() +} + +type importClientFactoryImpl struct { + conns *common.GRPCConns + splitCli split.SplitClient + tls *common.TLS + tcpConcurrency int +} + +func newImportClientFactoryImpl(splitCli split.SplitClient, tls *common.TLS, tcpConcurrency int) *importClientFactoryImpl { + return &importClientFactoryImpl{ + conns: common.NewGRPCConns(), + splitCli: splitCli, + tls: tls, + tcpConcurrency: tcpConcurrency, + } +} + +func (f *importClientFactoryImpl) makeConn(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { + store, err := f.splitCli.GetStore(ctx, storeID) + if err != nil { + return nil, errors.Trace(err) + } + opt := grpc.WithInsecure() + if f.tls.TLSConfig() != nil { + opt = grpc.WithTransportCredentials(credentials.NewTLS(f.tls.TLSConfig())) + } + ctx, cancel := context.WithTimeout(ctx, dialTimeout) + + bfConf := backoff.DefaultConfig + bfConf.MaxDelay = gRPCBackOffMaxDelay + // we should use peer address for tiflash. 
for tikv, peer address is empty + addr := store.GetPeerAddress() + if addr == "" { + addr = store.GetAddress() + } + conn, err := grpc.DialContext( + ctx, + addr, + opt, + grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: gRPCKeepAliveTime, + Timeout: gRPCKeepAliveTimeout, + PermitWithoutStream: true, + }), + ) + cancel() + if err != nil { + return nil, errors.Trace(err) + } + return conn, nil +} + +func (f *importClientFactoryImpl) getGrpcConn(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { + return f.conns.GetGrpcConn(ctx, storeID, f.tcpConcurrency, + func(ctx context.Context) (*grpc.ClientConn, error) { + return f.makeConn(ctx, storeID) + }) +} + +func (f *importClientFactoryImpl) Create(ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) { + conn, err := f.getGrpcConn(ctx, storeID) + if err != nil { + return nil, err + } + return sst.NewImportSSTClient(conn), nil +} + +func (f *importClientFactoryImpl) Close() { + f.conns.Close() +} // Range record start and end key for localStoreDir.DB // so we can write it to tikv in streaming @@ -125,7 +201,6 @@ type local struct { engines sync.Map // sync version of map[uuid.UUID]*Engine pdCtl *pdutil.PdController - conns common.GRPCConns splitCli split.SplitClient tikvCli *tikvclient.KVStore tls *common.TLS @@ -139,17 +214,19 @@ type local struct { batchWriteKVPairs int checkpointEnabled bool - tcpConcurrency int - maxOpenFiles int + dupeConcurrency int + maxOpenFiles int engineMemCacheSize int localWriterMemCacheSize int64 supportMultiIngest bool - checkTiKVAvaliable bool - duplicateDetection bool - duplicateDB *pebble.DB - errorMgr *errormanager.ErrorManager + checkTiKVAvaliable bool + duplicateDetection bool + duplicateDB *pebble.DB + keyAdapter KeyAdapter + errorMgr *errormanager.ErrorManager + importClientFactory ImportClientFactory bufferPool *membuf.Pool } @@ -220,7 +297,12 @@ func NewLocalBackend( if err != nil { return backend.MakeBackend(nil), err } - + importClientFactory := newImportClientFactoryImpl(splitCli, tls, rangeConcurrency) + duplicateDetection := cfg.TikvImporter.DuplicateResolution != config.DupeResAlgNone + keyAdapter := KeyAdapter(noopKeyAdapter{}) + if duplicateDetection { + keyAdapter = dupDetectKeyAdapter{} + } local := &local{ engines: sync.Map{}, pdCtl: pdCtl, @@ -233,21 +315,21 @@ func NewLocalBackend( localStoreDir: localFile, rangeConcurrency: worker.NewPool(ctx, rangeConcurrency, "range"), ingestConcurrency: worker.NewPool(ctx, rangeConcurrency*2, "ingest"), - tcpConcurrency: rangeConcurrency, + dupeConcurrency: rangeConcurrency * 2, batchWriteKVPairs: cfg.TikvImporter.SendKVPairs, checkpointEnabled: cfg.Checkpoint.Enable, maxOpenFiles: utils.MaxInt(maxOpenFiles, openFilesLowerThreshold), engineMemCacheSize: int(cfg.TikvImporter.EngineMemCacheSize), localWriterMemCacheSize: int64(cfg.TikvImporter.LocalWriterMemCacheSize), - duplicateDetection: cfg.TikvImporter.DuplicateResolution != config.DupeResAlgNone, + duplicateDetection: duplicateDetection, checkTiKVAvaliable: cfg.App.CheckRequirements, duplicateDB: duplicateDB, + keyAdapter: keyAdapter, errorMgr: errorMgr, - - bufferPool: membuf.NewPool(membuf.WithAllocator(manual.Allocator{})), + importClientFactory: importClientFactory, + bufferPool: membuf.NewPool(membuf.WithAllocator(manual.Allocator{})), } - local.conns = common.NewGRPCConns() if err = local.checkMultiIngestSupport(ctx); err != nil { return backend.MakeBackend(nil), err } @@ -372,49 +454,6 @@ func 
(local *local) lockAllEnginesUnless(newState, ignoreStateMask importMutexSt return allEngines } -func (local *local) makeConn(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { - store, err := local.splitCli.GetStore(ctx, storeID) - if err != nil { - return nil, errors.Trace(err) - } - opt := grpc.WithInsecure() - if local.tls.TLSConfig() != nil { - opt = grpc.WithTransportCredentials(credentials.NewTLS(local.tls.TLSConfig())) - } - ctx, cancel := context.WithTimeout(ctx, dialTimeout) - - bfConf := backoff.DefaultConfig - bfConf.MaxDelay = gRPCBackOffMaxDelay - // we should use peer address for tiflash. for tikv, peer address is empty - addr := store.GetPeerAddress() - if addr == "" { - addr = store.GetAddress() - } - conn, err := grpc.DialContext( - ctx, - addr, - opt, - grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: gRPCKeepAliveTime, - Timeout: gRPCKeepAliveTimeout, - PermitWithoutStream: true, - }), - ) - cancel() - if err != nil { - return nil, errors.Trace(err) - } - return conn, nil -} - -func (local *local) getGrpcConn(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { - return local.conns.GetGrpcConn(ctx, storeID, local.tcpConcurrency, - func(ctx context.Context) (*grpc.ClientConn, error) { - return local.makeConn(ctx, storeID) - }) -} - // Close the local backend. func (local *local) Close() { allEngines := local.lockAllEnginesUnless(importMutexStateClose, 0) @@ -424,11 +463,11 @@ func (local *local) Close() { engine.Close() engine.unlock() } - local.conns.Close() + local.importClientFactory.Close() local.bufferPool.Destroy() if local.duplicateDB != nil { - // Check whether there are duplicates. + // Check if there are duplicates that are not collected. iter := local.duplicateDB.NewIter(&pebble.IterOptions{}) hasDuplicates := iter.First() allIsWell := true @@ -444,7 +483,7 @@ func (local *local) Close() { log.L().Warn("close duplicate db failed", zap.Error(err)) allIsWell = false } - // If checkpoint is disabled or we don't detect any duplicate, then this duplicate + // If checkpoint is disabled, or we don't detect any duplicate, then this duplicate // db dir will be useless, so we clean up this dir. 
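+		// Otherwise keep the dir, since the recorded duplicates may still be needed by a resumed import.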
if allIsWell && (!local.checkpointEnabled || !hasDuplicates) { if err := os.RemoveAll(filepath.Join(local.localStoreDir, duplicateDBName)); err != nil { @@ -559,10 +598,6 @@ func (local *local) OpenEngine(ctx context.Context, cfg *backend.EngineConfig, e } engineCtx, cancel := context.WithCancel(ctx) - keyAdapter := KeyAdapter(noopKeyAdapter{}) - if local.duplicateDetection { - keyAdapter = dupDetectKeyAdapter{} - } e, _ := local.engines.LoadOrStore(engineUUID, &Engine{ UUID: engineUUID, sstDir: sstDir, @@ -574,7 +609,7 @@ func (local *local) OpenEngine(ctx context.Context, cfg *backend.EngineConfig, e duplicateDetection: local.duplicateDetection, duplicateDB: local.duplicateDB, errorMgr: local.errorMgr, - keyAdapter: keyAdapter, + keyAdapter: local.keyAdapter, }) engine := e.(*Engine) engine.db = db @@ -619,16 +654,12 @@ func (local *local) CloseEngine(ctx context.Context, cfg *backend.EngineConfig, db: db, sstMetasChan: make(chan metaOrFlush), tableInfo: cfg.TableInfo, + keyAdapter: local.keyAdapter, duplicateDetection: local.duplicateDetection, duplicateDB: local.duplicateDB, errorMgr: local.errorMgr, } engine.sstIngester = dbSSTIngester{e: engine} - keyAdapter := KeyAdapter(noopKeyAdapter{}) - if local.duplicateDetection { - keyAdapter = dupDetectKeyAdapter{} - } - engine.keyAdapter = keyAdapter if err = engine.loadEngineMeta(); err != nil { return err } @@ -660,15 +691,7 @@ func (local *local) CloseEngine(ctx context.Context, cfg *backend.EngineConfig, } func (local *local) getImportClient(ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) { - return getImportClientFn(local, ctx, storeID) -} - -func getImportClient(local *local, ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) { - conn, err := local.getGrpcConn(ctx, storeID) - if err != nil { - return nil, err - } - return sst.NewImportSSTClient(conn), nil + return local.importClientFactory.Create(ctx, storeID) } type rangeStats struct { @@ -924,30 +947,32 @@ func splitRangeBySizeProps(fullRange Range, sizeProps *sizeProperties, sizeLimit curSize := uint64(0) curKeys := uint64(0) curKey := fullRange.start + sizeProps.iter(func(p *rangeProperty) bool { - if bytes.Equal(p.Key, engineMetaKey) { + if bytes.Compare(p.Key, curKey) <= 0 { return true } + if bytes.Compare(p.Key, fullRange.end) > 0 { + return false + } curSize += p.Size curKeys += p.Keys if int64(curSize) >= sizeLimit || int64(curKeys) >= keysLimit { - // in case the sizeLimit or keysLimit is too small - endKey := p.Key - if bytes.Equal(curKey, endKey) { - endKey = nextKey(endKey) - } - ranges = append(ranges, Range{start: curKey, end: endKey}) - curKey = endKey + ranges = append(ranges, Range{start: curKey, end: p.Key}) + curKey = p.Key curSize = 0 curKeys = 0 } return true }) - if curKeys > 0 { - ranges = append(ranges, Range{start: curKey, end: fullRange.end}) - } else { - ranges[len(ranges)-1].end = fullRange.end + if bytes.Compare(curKey, fullRange.end) < 0 { + // If the remaining range is too small, append it to last range. 
+ if len(ranges) > 0 && curKeys == 0 { + ranges[len(ranges)-1].end = fullRange.end + } else { + ranges = append(ranges, Range{start: curKey, end: fullRange.end}) + } } return ranges } @@ -986,7 +1011,8 @@ func (local *local) readAndSplitIntoRange(ctx context.Context, engine *Engine, r return ranges, nil } - sizeProps, err := engine.getSizeProperties() + logger := log.With(zap.Stringer("engine", engine.UUID)) + sizeProps, err := getSizeProperties(logger, engine.db, local.keyAdapter) if err != nil { return nil, errors.Trace(err) } @@ -994,7 +1020,7 @@ func (local *local) readAndSplitIntoRange(ctx context.Context, engine *Engine, r ranges := splitRangeBySizeProps(Range{start: firstKey, end: endKey}, sizeProps, regionSplitSize, regionSplitKeys) - log.L().Info("split engine key ranges", zap.Stringer("engine", engine.UUID), + logger.Info("split engine key ranges", zap.Int64("totalSize", engineFileTotalSize), zap.Int64("totalCount", engineFileLength), logutil.Key("firstKey", firstKey), logutil.Key("lastKey", lastKey), zap.Int("ranges", len(ranges))) @@ -1342,16 +1368,6 @@ func (local *local) ImportEngine(ctx context.Context, engineUUID uuid.UUID, regi } } - if lf.Duplicates.Load() > 0 { - if err := lf.saveEngineMeta(); err != nil { - log.L().Error("failed to save engine meta", log.ShortError(err)) - return err - } - log.L().Warn("duplicate detected during import engine", zap.Stringer("uuid", engineUUID), - zap.Int64("size", lfTotalSize), zap.Int64("kvs", lfLength), zap.Int64("duplicate-kvs", lf.Duplicates.Load()), - zap.Int64("importedSize", lf.importedKVSize.Load()), zap.Int64("importedCount", lf.importedKVCount.Load())) - } - log.L().Info("import engine success", zap.Stringer("uuid", engineUUID), zap.Int64("size", lfTotalSize), zap.Int64("kvs", lfLength), zap.Int64("importedSize", lf.importedKVSize.Load()), zap.Int64("importedCount", lf.importedKVCount.Load())) @@ -1359,48 +1375,39 @@ func (local *local) ImportEngine(ctx context.Context, engineUUID uuid.UUID, regi } func (local *local) CollectLocalDuplicateRows(ctx context.Context, tbl table.Table, tableName string, opts *kv.SessionOptions) (hasDupe bool, err error) { - if local.duplicateDB == nil { - return false, nil - } - - logger := log.With(zap.String("table", tableName)).Begin(zap.InfoLevel, "[detect-dupe] collect duplicate local keys") + logger := log.With(zap.String("table", tableName)).Begin(zap.InfoLevel, "[detect-dupe] collect local duplicate keys") defer func() { logger.End(zap.ErrorLevel, err) }() - physicalTS, logicalTS, err := local.pdCtl.GetPDClient().GetTS(ctx) + atomicHasDupe := atomic.NewBool(false) + duplicateManager, err := NewDuplicateManager(tbl, tableName, local.splitCli, local.tikvCli, + local.errorMgr, opts, local.dupeConcurrency, atomicHasDupe) if err != nil { - return false, err + return false, errors.Trace(err) } - ts := oracle.ComposeTS(physicalTS, logicalTS) - duplicateManager, err := NewDuplicateManager(local, ts, opts) - if err != nil { - return false, errors.Annotate(err, "open duplicatemanager failed") + if err := duplicateManager.CollectDuplicateRowsFromDupDB(ctx, local.duplicateDB, local.keyAdapter); err != nil { + return false, errors.Trace(err) } - hasDupe, err = duplicateManager.CollectDuplicateRowsFromLocalIndex(ctx, tbl, tableName, local.duplicateDB) - if err != nil { - return false, errors.Annotate(err, "collect local duplicate rows failed") - } - return hasDupe, nil + return atomicHasDupe.Load(), nil } -func (local *local) CollectRemoteDuplicateRows(ctx context.Context, tbl table.Table, tableName 
string, opts *kv.SessionOptions) (bool, error) { - log.L().Info("Begin collect remote duplicate keys", zap.String("table", tableName)) - physicalTS, logicalTS, err := local.pdCtl.GetPDClient().GetTS(ctx) - if err != nil { - return false, err - } - ts := oracle.ComposeTS(physicalTS, logicalTS) +func (local *local) CollectRemoteDuplicateRows(ctx context.Context, tbl table.Table, tableName string, opts *kv.SessionOptions) (hasDupe bool, err error) { + logger := log.With(zap.String("table", tableName)).Begin(zap.InfoLevel, "[detect-dupe] collect remote duplicate keys") + defer func() { + logger.End(zap.ErrorLevel, err) + }() - duplicateManager, err := NewDuplicateManager(local, ts, opts) + atomicHasDupe := atomic.NewBool(false) + duplicateManager, err := NewDuplicateManager(tbl, tableName, local.splitCli, local.tikvCli, + local.errorMgr, opts, local.dupeConcurrency, atomicHasDupe) if err != nil { - return false, errors.Annotate(err, "open duplicatemanager failed") + return false, errors.Trace(err) } - hasDupe, err := duplicateManager.CollectDuplicateRowsFromTiKV(ctx, tbl, tableName) - if err != nil { - return false, errors.Annotate(err, "collect remote duplicate rows failed") + if err := duplicateManager.CollectDuplicateRowsFromTiKV(ctx, local.importClientFactory); err != nil { + return false, errors.Trace(err) } - return hasDupe, nil + return atomicHasDupe.Load(), nil } func (local *local) ResolveDuplicateRows(ctx context.Context, tbl table.Table, tableName string, algorithm config.DuplicateResolutionAlgorithm) (err error) { @@ -1427,21 +1434,29 @@ func (local *local) ResolveDuplicateRows(ctx context.Context, tbl table.Table, t return err } - preRowID := int64(0) - for { - handleRows, lastRowID, err := local.errorMgr.GetConflictKeys(ctx, tableName, preRowID, 1000) - if err != nil { - return errors.Annotate(err, "cannot query conflict keys") - } - if len(handleRows) == 0 { - break - } - if err := local.deleteDuplicateRows(ctx, logger, handleRows, decoder); err != nil { - return errors.Annotate(err, "cannot delete duplicated entries") - } - preRowID = lastRowID - } - return nil + errLimiter := rate.NewLimiter(1, 1) + pool := utils.NewWorkerPool(uint(local.dupeConcurrency), "resolve duplicate rows") + err = local.errorMgr.ResolveAllConflictKeys( + ctx, tableName, pool, + func(ctx context.Context, handleRows [][2][]byte) error { + for { + err := local.deleteDuplicateRows(ctx, logger, handleRows, decoder) + if err == nil { + return nil + } + if log.IsContextCanceledError(err) { + return err + } + if !tikverror.IsErrWriteConflict(errors.Cause(err)) { + logger.Warn("delete duplicate rows encounter error", log.ShortError(err)) + } + if err = errLimiter.Wait(ctx); err != nil { + return err + } + } + }, + ) + return errors.Trace(err) } func (local *local) deleteDuplicateRows(ctx context.Context, logger *log.Task, handleRows [][2][]byte, decoder *kv.TableKVDecoder) (err error) { @@ -1450,7 +1465,6 @@ func (local *local) deleteDuplicateRows(ctx context.Context, logger *log.Task, h if err != nil { return err } - txn.SetPessimistic(true) defer func() { if err == nil { err = txn.Commit(ctx) @@ -1478,7 +1492,7 @@ func (local *local) deleteDuplicateRows(ctx context.Context, logger *log.Task, h return err } - handle, err := decoder.DecodeHandleFromTable(handleRow[0]) + handle, err := decoder.DecodeHandleFromRowKey(handleRow[0]) if err != nil { return err } @@ -1489,7 +1503,7 @@ func (local *local) deleteDuplicateRows(ctx context.Context, logger *log.Task, h } } - logger.Info("[resolve-dupe] number of KV pairs 
to be deleted", zap.Int("count", txn.Len())) + logger.Debug("[resolve-dupe] number of KV pairs to be deleted", zap.Int("count", txn.Len())) return nil } @@ -1509,15 +1523,8 @@ func (local *local) ResetEngine(ctx context.Context, engineUUID uuid.UUID) error } db, err := local.openEngineDB(engineUUID, false) if err == nil { - // Reset engineMeta except `Duplicates`. - meta := engineMeta{ - Duplicates: *atomic.NewInt64(localEngine.engineMeta.Duplicates.Load()), - } - if err := saveEngineMetaToDB(&meta, db); err != nil { - return errors.Trace(err) - } localEngine.db = db - localEngine.engineMeta = meta + localEngine.engineMeta = engineMeta{} if !common.IsDirExists(localEngine.sstDir) { if err := os.Mkdir(localEngine.sstDir, 0o755); err != nil { return errors.Trace(err) diff --git a/br/pkg/lightning/backend/local/local_test.go b/br/pkg/lightning/backend/local/local_test.go index 35c13692dce3e..50e3ab5d4503c 100644 --- a/br/pkg/lightning/backend/local/local_test.go +++ b/br/pkg/lightning/backend/local/local_test.go @@ -168,14 +168,14 @@ func (s *localSuite) TestRangeProperties(c *C) { for _, p := range cases { v := make([]byte, p.vLen) for i := 0; i < p.count; i++ { - _ = collector.Add(pebble.InternalKey{UserKey: p.key}, v) + _ = collector.Add(pebble.InternalKey{UserKey: p.key, Trailer: pebble.InternalKeyKindSet}, v) } } userProperties := make(map[string]string, 1) _ = collector.Finish(userProperties) - props, err := decodeRangeProperties(hack.Slice(userProperties[propRangeIndex])) + props, err := decodeRangeProperties(hack.Slice(userProperties[propRangeIndex]), noopKeyAdapter{}) c.Assert(err, IsNil) // Smallest key in props. @@ -302,7 +302,7 @@ func (s *localSuite) TestRangePropertiesWithPebble(c *C) { binary.BigEndian.PutUint64(key, uint64(i*100+j)) err = wb.Set(key, value[:valueLen], writeOpt) c.Assert(err, IsNil) - err = collector.Add(pebble.InternalKey{UserKey: key}, value[:valueLen]) + err = collector.Add(pebble.InternalKey{UserKey: key, Trailer: pebble.InternalKeyKindSet}, value[:valueLen]) c.Assert(err, IsNil) } c.Assert(wb.Commit(writeOpt), IsNil) @@ -900,8 +900,7 @@ func (e mockGrpcErr) Error() string { type mockImportClient struct { sst.ImportSSTClient - stores []*metapb.Store - curStore *metapb.Store + store *metapb.Store err error retry int cnt int @@ -916,32 +915,31 @@ func (c *mockImportClient) MultiIngest(context.Context, *sst.MultiIngestRequest, return nil, c.err } - if !c.multiIngestCheckFn(c.curStore) { + if !c.multiIngestCheckFn(c.store) { return nil, mockGrpcErr{} } return nil, nil } -type testMultiIngestSuite struct { - local *local - pdCli *mockPdClient +type mockImportClientFactory struct { + stores []*metapb.Store + createClientFn func(store *metapb.Store) sst.ImportSSTClient } -func (s *testMultiIngestSuite) SetUpSuite(c *C) { - local := &local{ - pdCtl: &pdutil.PdController{}, +func (f *mockImportClientFactory) Create(_ context.Context, storeID uint64) (sst.ImportSSTClient, error) { + for _, store := range f.stores { + if store.Id == storeID { + return f.createClientFn(store), nil + } } - pdCli := &mockPdClient{} - local.pdCtl.SetPDClient(pdCli) - s.local = local - s.pdCli = pdCli + return nil, errors.New("store not found") } -func (s *testMultiIngestSuite) TestMultiIngest(c *C) { - defer func() { - getImportClientFn = getImportClient - }() +func (f *mockImportClientFactory) Close() {} + +type testMultiIngestSuite struct{} +func (s *testMultiIngestSuite) TestMultiIngest(c *C) { allStores := []*metapb.Store{ { Id: 1, @@ -1191,30 +1189,29 @@ func (s 
*testMultiIngestSuite) TestMultiIngest(c *C) { } importCli := &mockImportClient{ - stores: allStores, cnt: 0, retry: testCase.retry, err: testCase.err, multiIngestCheckFn: testCase.multiIngestSupport, } - s.pdCli.stores = stores - - getImportClientFn = func(local *local, ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) { - for _, store := range importCli.stores { - if store.Id == storeID { - importCli.curStore = store - break - } - } - return importCli, nil + pdCtl := &pdutil.PdController{} + pdCtl.SetPDClient(&mockPdClient{stores: stores}) + + local := &local{ + pdCtl: pdCtl, + importClientFactory: &mockImportClientFactory{ + stores: allStores, + createClientFn: func(store *metapb.Store) sst.ImportSSTClient { + importCli.store = store + return importCli + }, + }, } - s.local.supportMultiIngest = false - - err := s.local.checkMultiIngestSupport(context.Background()) + err := local.checkMultiIngestSupport(context.Background()) if err != nil { c.Assert(err, ErrorMatches, testCase.retErr) } else { - c.Assert(s.local.supportMultiIngest, Equals, testCase.supportMutliIngest) + c.Assert(local.supportMultiIngest, Equals, testCase.supportMutliIngest) } } } diff --git a/br/pkg/lightning/common/conn.go b/br/pkg/lightning/common/conn.go index 0dc011e88b7fa..83e6d3307412d 100644 --- a/br/pkg/lightning/common/conn.go +++ b/br/pkg/lightning/common/conn.go @@ -24,7 +24,7 @@ import ( "google.golang.org/grpc" ) -// connPool is a lazy pool of gRPC channels. +// ConnPool is a lazy pool of gRPC channels. // When `Get` called, it lazily allocates new connection if connection not full. // If it's full, then it will return allocated channels round-robin. type ConnPool struct { @@ -71,7 +71,7 @@ func (p *ConnPool) get(ctx context.Context) (*grpc.ClientConn, error) { return conn, nil } -// newConnPool creates a new connPool by the specified conn factory function and capacity. +// NewConnPool creates a new connPool by the specified conn factory function and capacity. 
func NewConnPool(cap int, newConn func(ctx context.Context) (*grpc.ClientConn, error)) *ConnPool { return &ConnPool{ cap: cap, @@ -105,7 +105,7 @@ func (conns *GRPCConns) GetGrpcConn(ctx context.Context, storeID uint64, tcpConc return conns.conns[storeID].get(ctx) } -func NewGRPCConns() GRPCConns { - cons := GRPCConns{conns: make(map[uint64]*ConnPool)} - return cons +func NewGRPCConns() *GRPCConns { + conns := &GRPCConns{conns: make(map[uint64]*ConnPool)} + return conns } diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index 9c2e6e09186ba..c24ee74fe37a2 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -58,7 +58,7 @@ func (param *MySQLConnectParam) ToDSN() string { param.SQLMode, param.MaxAllowedPacket, param.TLS) for k, v := range param.Vars { - dsn += fmt.Sprintf("&%s=%s", k, url.QueryEscape(v)) + dsn += fmt.Sprintf("&%s='%s'", k, url.QueryEscape(v)) } return dsn diff --git a/br/pkg/lightning/common/util_test.go b/br/pkg/lightning/common/util_test.go index 60812841ff259..3915a318b06c5 100644 --- a/br/pkg/lightning/common/util_test.go +++ b/br/pkg/lightning/common/util_test.go @@ -91,7 +91,7 @@ func (s *utilSuite) TestToDSN(c *C) { "tidb_distsql_scan_concurrency": "1", }, } - c.Assert(param.ToDSN(), Equals, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency=1") + c.Assert(param.ToDSN(), Equals, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'") } func (s *utilSuite) TestIsContextCanceledError(c *C) { diff --git a/br/pkg/lightning/errormanager/errormanager.go b/br/pkg/lightning/errormanager/errormanager.go index 4d1e7c0fedf60..0b6edefaa433c 100644 --- a/br/pkg/lightning/errormanager/errormanager.go +++ b/br/pkg/lightning/errormanager/errormanager.go @@ -18,15 +18,19 @@ import ( "context" "database/sql" "fmt" + "math" "strings" + "sync" "github.com/pingcap/errors" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/config" "github.com/pingcap/tidb/br/pkg/lightning/log" "github.com/pingcap/tidb/br/pkg/redact" + "github.com/pingcap/tidb/br/pkg/utils" "go.uber.org/multierr" "go.uber.org/zap" + "golang.org/x/sync/errgroup" ) const ( @@ -87,19 +91,23 @@ const ( insertIntoConflictErrorData = ` INSERT INTO %s.` + conflictErrorTableName + ` (task_id, table_name, index_name, key_data, row_data, raw_key, raw_value, raw_handle, raw_row) - VALUES (?, ?, 'PRIMARY', ?, ?, ?, ?, raw_key, raw_value); + VALUES ` + sqlValuesConflictErrorData = "(?,?,'PRIMARY',?,?,?,?,raw_key,raw_value)" + insertIntoConflictErrorIndex = ` INSERT INTO %s.` + conflictErrorTableName + ` (task_id, table_name, index_name, key_data, row_data, raw_key, raw_value, raw_handle, raw_row) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); + VALUES ` + sqlValuesConflictErrorIndex = "(?,?,?,?,?,?,?,?,?)" + selectConflictKeys = ` SELECT _tidb_rowid, raw_handle, raw_row FROM %s.` + conflictErrorTableName + ` - WHERE table_name = ? AND _tidb_rowid > ? + WHERE table_name = ? AND _tidb_rowid >= ? and _tidb_rowid < ? 
ORDER BY _tidb_rowid LIMIT ?; ` ) @@ -224,6 +232,9 @@ func (em *ErrorManager) RecordDataConflictError( if em.db == nil { return nil } + if len(conflictInfos) == 0 { + return nil + } exec := common.SQLWithRetry{ DB: em.db, @@ -231,13 +242,15 @@ func (em *ErrorManager) RecordDataConflictError( HideQueryLog: redact.NeedRedact(), } return exec.Transact(ctx, "insert data conflict error record", func(c context.Context, txn *sql.Tx) error { - stmt, err := txn.PrepareContext(c, fmt.Sprintf(insertIntoConflictErrorData, em.schemaEscaped)) - if err != nil { - return err - } - defer stmt.Close() - for _, conflictInfo := range conflictInfos { - _, err = stmt.ExecContext(c, + sb := &strings.Builder{} + fmt.Fprintf(sb, insertIntoConflictErrorData, em.schemaEscaped) + var sqlArgs []interface{} + for i, conflictInfo := range conflictInfos { + if i > 0 { + sb.WriteByte(',') + } + sb.WriteString(sqlValuesConflictErrorData) + sqlArgs = append(sqlArgs, em.taskID, tableName, conflictInfo.KeyData, @@ -245,11 +258,9 @@ func (em *ErrorManager) RecordDataConflictError( conflictInfo.RawKey, conflictInfo.RawValue, ) - if err != nil { - return err - } } - return nil + _, err := txn.ExecContext(c, sb.String(), sqlArgs...) + return err }) } @@ -264,6 +275,9 @@ func (em *ErrorManager) RecordIndexConflictError( if em.db == nil { return nil } + if len(conflictInfos) == 0 { + return nil + } exec := common.SQLWithRetry{ DB: em.db, @@ -271,13 +285,15 @@ func (em *ErrorManager) RecordIndexConflictError( HideQueryLog: redact.NeedRedact(), } return exec.Transact(ctx, "insert index conflict error record", func(c context.Context, txn *sql.Tx) error { - stmt, err := txn.PrepareContext(c, fmt.Sprintf(insertIntoConflictErrorIndex, em.schemaEscaped)) - if err != nil { - return err - } - defer stmt.Close() + sb := &strings.Builder{} + fmt.Fprintf(sb, insertIntoConflictErrorIndex, em.schemaEscaped) + var sqlArgs []interface{} for i, conflictInfo := range conflictInfos { - _, err = stmt.ExecContext(c, + if i > 0 { + sb.WriteByte(',') + } + sb.WriteString(sqlValuesConflictErrorIndex) + sqlArgs = append(sqlArgs, em.taskID, tableName, indexNames[i], @@ -288,38 +304,87 @@ func (em *ErrorManager) RecordIndexConflictError( rawHandles[i], rawRows[i], ) - if err != nil { - return err - } } - return nil + _, err := txn.ExecContext(c, sb.String(), sqlArgs...) + return err }) } -// GetConflictKeys obtains all (distinct) conflicting rows (handle and their -// values) from the current error report. -func (em *ErrorManager) GetConflictKeys(ctx context.Context, tableName string, prevRowID int64, limit int) (handleRows [][2][]byte, lastRowID int64, err error) { +// ResolveAllConflictKeys query all conflicting rows (handle and their +// values) from the current error report and resolve them concurrently. 
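The RecordDataConflictError/RecordIndexConflictError changes above drop the per-row prepared-statement execution in favour of a single multi-row INSERT whose VALUES clause is a repeated placeholder group. A standalone sketch of that construction, with a simplified two-column schema standing in for the real conflict_error_v1 columns:

package main

import (
	"fmt"
	"strings"
)

// buildBatchInsert assembles one multi-row INSERT: a shared prefix plus a
// repeated "(?,?)" placeholder group, collecting the bound arguments in
// matching order. Schema and column names are simplified stand-ins.
func buildBatchInsert(rows [][2]string) (string, []interface{}) {
	sb := &strings.Builder{}
	sb.WriteString("INSERT INTO lightning_errors.conflict_error_v1 (raw_key, raw_value) VALUES ")
	args := make([]interface{}, 0, len(rows)*2)
	for i, row := range rows {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString("(?,?)")
		args = append(args, row[0], row[1])
	}
	return sb.String(), args
}

func main() {
	stmt, args := buildBatchInsert([][2]string{{"k1", "v1"}, {"k2", "v2"}})
	fmt.Println(stmt)      // ... VALUES (?,?),(?,?)
	fmt.Println(len(args)) // 4
}

The early return added above for an empty conflictInfos slice also keeps this builder from ever emitting an INSERT with no VALUES group.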
+func (em *ErrorManager) ResolveAllConflictKeys( + ctx context.Context, + tableName string, + pool *utils.WorkerPool, + fn func(ctx context.Context, handleRows [][2][]byte) error, +) error { if em.db == nil { - return nil, 0, nil - } - rows, err := em.db.QueryContext( - ctx, - fmt.Sprintf(selectConflictKeys, em.schemaEscaped), - tableName, - prevRowID, - limit, - ) - if err != nil { - return nil, 0, errors.Trace(err) + return nil } - defer rows.Close() - for rows.Next() { - var handleRow [2][]byte - if err := rows.Scan(&lastRowID, &handleRow[0], &handleRow[1]); err != nil { - return nil, 0, errors.Trace(err) - } - handleRows = append(handleRows, handleRow) + const rowLimit = 1000 + taskCh := make(chan [2]int64) + taskWg := &sync.WaitGroup{} + g, gCtx := errgroup.WithContext(ctx) + + go func() { + //nolint:staticcheck + taskWg.Add(1) + taskCh <- [2]int64{0, math.MaxInt64} + taskWg.Wait() + close(taskCh) + }() + + for t := range taskCh { + start, end := t[0], t[1] + pool.ApplyOnErrorGroup(g, func() error { + defer taskWg.Done() + + var handleRows [][2][]byte + for start < end { + rows, err := em.db.QueryContext( + gCtx, fmt.Sprintf(selectConflictKeys, em.schemaEscaped), + tableName, start, end, rowLimit) + if err != nil { + return errors.Trace(err) + } + var lastRowID int64 + for rows.Next() { + var handleRow [2][]byte + if err := rows.Scan(&lastRowID, &handleRow[0], &handleRow[1]); err != nil { + return errors.Trace(err) + } + handleRows = append(handleRows, handleRow) + } + if err := rows.Err(); err != nil { + return errors.Trace(err) + } + if err := rows.Close(); err != nil { + return errors.Trace(err) + } + if len(handleRows) == 0 { + break + } + if err := fn(gCtx, handleRows); err != nil { + return errors.Trace(err) + } + start = lastRowID + 1 + // If the remaining tasks cannot be processed at once, split the task + // into two subtasks and send one of them to the other idle worker if possible. + if end-start > rowLimit { + mid := start + (end-start)/2 + taskWg.Add(1) + select { + case taskCh <- [2]int64{mid, end}: + end = mid + default: + taskWg.Done() + } + } + handleRows = handleRows[:0] + } + return nil + }) } - return handleRows, lastRowID, errors.Trace(rows.Err()) + return errors.Trace(g.Wait()) } diff --git a/br/pkg/lightning/errormanager/errormanager_test.go b/br/pkg/lightning/errormanager/errormanager_test.go index 2b5aba0e07605..87e31e57b54c1 100644 --- a/br/pkg/lightning/errormanager/errormanager_test.go +++ b/br/pkg/lightning/errormanager/errormanager_test.go @@ -16,24 +16,23 @@ package errormanager import ( "context" + "database/sql" + "database/sql/driver" + "io" + "math/rand" + "strconv" "testing" "github.com/DATA-DOG/go-sqlmock" - . 
"github.com/pingcap/check" "github.com/pingcap/tidb/br/pkg/lightning/config" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) -var _ = Suite(errorManagerSuite{}) - -func TestErrorManager(t *testing.T) { - TestingT(t) -} - -type errorManagerSuite struct{} - -func (e errorManagerSuite) TestInit(c *C) { +func TestInit(t *testing.T) { db, mock, err := sqlmock.New() - c.Assert(err, IsNil) + require.NoError(t, err) cfg := config.NewConfig() cfg.TikvImporter.DuplicateResolution = config.DupeResAlgRecord @@ -41,15 +40,15 @@ func (e errorManagerSuite) TestInit(c *C) { cfg.App.TaskInfoSchemaName = "lightning_errors" em := New(db, cfg) - c.Assert(em.dupResolution, Equals, cfg.TikvImporter.DuplicateResolution) - c.Assert(em.remainingError.Type.Load(), Equals, cfg.App.MaxError.Type.Load()) - c.Assert(em.remainingError.Conflict.Load(), Equals, cfg.App.MaxError.Conflict.Load()) + require.Equal(t, cfg.TikvImporter.DuplicateResolution, em.dupResolution) + require.Equal(t, cfg.App.MaxError.Type.Load(), em.remainingError.Type.Load()) + require.Equal(t, cfg.App.MaxError.Conflict.Load(), em.remainingError.Conflict.Load()) em.remainingError.Type.Store(0) em.dupResolution = config.DupeResAlgNone ctx := context.Background() err = em.Init(ctx) - c.Assert(err, IsNil) + require.NoError(t, err) em.dupResolution = config.DupeResAlgRecord mock.ExpectExec("CREATE SCHEMA IF NOT EXISTS `lightning_errors`;"). @@ -57,7 +56,7 @@ func (e errorManagerSuite) TestInit(c *C) { mock.ExpectExec("CREATE TABLE IF NOT EXISTS `lightning_errors`\\.conflict_error_v1.*"). WillReturnResult(sqlmock.NewResult(2, 1)) err = em.Init(ctx) - c.Assert(err, IsNil) + require.NoError(t, err) em.dupResolution = config.DupeResAlgNone em.remainingError.Type.Store(1) @@ -66,8 +65,7 @@ func (e errorManagerSuite) TestInit(c *C) { mock.ExpectExec("CREATE TABLE IF NOT EXISTS `lightning_errors`\\.type_error_v1.*"). WillReturnResult(sqlmock.NewResult(4, 1)) err = em.Init(ctx) - c.Assert(err, IsNil) - + require.NoError(t, err) em.dupResolution = config.DupeResAlgRecord em.remainingError.Type.Store(1) mock.ExpectExec("CREATE SCHEMA IF NOT EXISTS `lightning_errors`.*"). @@ -77,7 +75,105 @@ func (e errorManagerSuite) TestInit(c *C) { mock.ExpectExec("CREATE TABLE IF NOT EXISTS `lightning_errors`\\.conflict_error_v1.*"). 
WillReturnResult(sqlmock.NewResult(7, 1)) err = em.Init(ctx) - c.Assert(err, IsNil) + require.NoError(t, err) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +type mockDriver struct { + driver.Driver + totalRows int64 +} + +func (m mockDriver) Open(_ string) (driver.Conn, error) { + return mockConn{totalRows: m.totalRows}, nil +} + +type mockConn struct { + driver.Conn + driver.ExecerContext + driver.QueryerContext + totalRows int64 +} + +func (c mockConn) ExecContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Result, error) { + return sqlmock.NewResult(1, 1), nil +} + +func (mockConn) Close() error { return nil } + +type mockRows struct { + driver.Rows + start int64 + end int64 +} + +func (r *mockRows) Columns() []string { + return []string{"_tidb_rowid", "raw_handle", "raw_row"} +} + +func (r *mockRows) Close() error { return nil } + +func (r *mockRows) Next(dest []driver.Value) error { + if r.start >= r.end { + return io.EOF + } + dest[0] = r.start // _tidb_rowid + dest[1] = []byte{} // raw_handle + dest[2] = []byte{} // raw_row + r.start++ + return nil +} + +func (c mockConn) QueryContext(_ context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + expectedQuery := "SELECT _tidb_rowid, raw_handle, raw_row.*" + if err := sqlmock.QueryMatcherRegexp.Match(expectedQuery, query); err != nil { + return &mockRows{}, nil + } + if len(args) != 4 { + return &mockRows{}, nil + } + // args are tableName, start, end, and limit. + start := args[1].Value.(int64) + if start < 1 { + start = 1 + } + end := args[2].Value.(int64) + if end > c.totalRows+1 { + end = c.totalRows + 1 + } + limit := args[3].Value.(int64) + if start+limit < end { + end = start + limit + } + return &mockRows{start: start, end: end}, nil +} + +func TestResolveAllConflictKeys(t *testing.T) { + const totalRows = int64(1 << 18) + driverName := "errmgr-mock-" + strconv.Itoa(rand.Int()) + sql.Register(driverName, mockDriver{totalRows: totalRows}) + db, err := sql.Open(driverName, "") + require.NoError(t, err) + defer db.Close() + + cfg := config.NewConfig() + cfg.TikvImporter.DuplicateResolution = config.DupeResAlgRemove + cfg.App.TaskInfoSchemaName = "lightning_errors" + em := New(db, cfg) + ctx := context.Background() + err = em.Init(ctx) + require.NoError(t, err) - c.Assert(mock.ExpectationsWereMet(), IsNil) + resolved := atomic.NewInt64(0) + pool := utils.NewWorkerPool(16, "resolve duplicate rows") + err = em.ResolveAllConflictKeys( + ctx, "test", pool, + func(ctx context.Context, handleRows [][2][]byte) error { + resolved.Add(int64(len(handleRows))) + return nil + }, + ) + require.NoError(t, err) + require.Equal(t, totalRows, resolved.Load()) } diff --git a/br/pkg/lightning/lightning.go b/br/pkg/lightning/lightning.go index f067a535faa85..02a8ec80b1f63 100644 --- a/br/pkg/lightning/lightning.go +++ b/br/pkg/lightning/lightning.go @@ -216,7 +216,7 @@ func (l *Lightning) RunServer() error { return err } err = l.run(context.Background(), task, nil) - if err != nil { + if err != nil && !common.IsContextCanceledError(err) { restore.DeliverPauser.Pause() // force pause the progress on error log.L().Error("tidb lightning encountered error", zap.Error(err)) } diff --git a/br/pkg/lightning/restore/check_template.go b/br/pkg/lightning/restore/check_template.go index f38e23aa00f8e..f02410116b3dd 100644 --- a/br/pkg/lightning/restore/check_template.go +++ b/br/pkg/lightning/restore/check_template.go @@ -15,7 +15,6 @@ package restore import ( - "fmt" "strings" 
"github.com/jedib0t/go-pretty/v6/table" @@ -124,17 +123,5 @@ func (c *SimpleTemplate) Output() string { } return nil }) - res := c.t.Render() - summary := "\n" - if c.criticalFailedCount > 0 { - summary += fmt.Sprintf("%d critical check failed", c.criticalFailedCount) - } - if c.warnFailedCount > 0 { - msg := fmt.Sprintf("%d performance check failed", c.warnFailedCount) - if len(summary) > 1 { - msg = "," + msg - } - summary += msg - } - return res + summary + return c.t.Render() + "\n" } diff --git a/br/pkg/lightning/restore/restore.go b/br/pkg/lightning/restore/restore.go index 34b01288210cc..d347b7cc324f6 100644 --- a/br/pkg/lightning/restore/restore.go +++ b/br/pkg/lightning/restore/restore.go @@ -457,7 +457,6 @@ outside: case err == nil: case log.IsContextCanceledError(err): logger.Info("task canceled") - err = nil break outside default: logger.Error("run failed") @@ -1337,7 +1336,7 @@ func (rc *Controller) keepPauseGCForDupeRes(ctx context.Context) (<-chan struct{ return exitCh, nil } -func (rc *Controller) restoreTables(ctx context.Context) error { +func (rc *Controller) restoreTables(ctx context.Context) (finalErr error) { if rc.cfg.TikvImporter.DuplicateResolution != config.DupeResAlgNone { subCtx, cancel := context.WithCancel(ctx) exitCh, err := rc.keepPauseGCForDupeRes(subCtx) @@ -1366,16 +1365,21 @@ func (rc *Controller) restoreTables(ctx context.Context) error { finishSchedulers := func() {} // if one lightning failed abnormally, and can't determine whether it needs to switch back, // we do not do switch back automatically - cleanupFunc := func() {} switchBack := false - taskFinished := false + cleanup := false + postProgress := func() error { return nil } if rc.cfg.TikvImporter.Backend == config.BackendLocal { logTask.Info("removing PD leader®ion schedulers") restoreFn, err := rc.taskMgr.CheckAndPausePdSchedulers(ctx) + if err != nil { + return errors.Trace(err) + } + finishSchedulers = func() { if restoreFn != nil { + taskFinished := finalErr == nil // use context.Background to make sure this restore function can still be executed even if ctx is canceled restoreCtx := context.Background() needSwitchBack, needCleanup, err := rc.taskMgr.CheckAndFinishRestore(restoreCtx, taskFinished) @@ -1385,39 +1389,17 @@ func (rc *Controller) restoreTables(ctx context.Context) error { } switchBack = needSwitchBack if needSwitchBack { + logTask.Info("add back PD leader®ion schedulers") if restoreE := restoreFn(restoreCtx); restoreE != nil { logTask.Warn("failed to restore removed schedulers, you may need to restore them manually", zap.Error(restoreE)) } - - logTask.Info("add back PD leader®ion schedulers") - // clean up task metas - if needCleanup { - logTask.Info("cleanup task metas") - if cleanupErr := rc.taskMgr.Cleanup(restoreCtx); cleanupErr != nil { - logTask.Warn("failed to clean task metas, you may need to restore them manually", zap.Error(cleanupErr)) - } - // cleanup table meta and schema db if needed. 
- cleanupFunc = func() { - if e := rc.taskMgr.CleanupAllMetas(restoreCtx); err != nil { - logTask.Warn("failed to clean table task metas, you may need to restore them manually", zap.Error(e)) - } - } - } } + cleanup = needCleanup } rc.taskMgr.Close() } - - if err != nil { - return errors.Trace(err) - } } - defer func() { - if switchBack { - cleanupFunc() - } - }() type task struct { tr *TableRestore @@ -1437,17 +1419,31 @@ func (rc *Controller) restoreTables(ctx context.Context) error { periodicActions, cancelFunc := rc.buildRunPeriodicActionAndCancelFunc(ctx, stopPeriodicActions) go periodicActions() - finishFuncCalled := false + + defer close(stopPeriodicActions) + defer func() { - if !finishFuncCalled { - finishSchedulers() - cancelFunc(switchBack) - finishFuncCalled = true + finishSchedulers() + cancelFunc(switchBack) + + if err := postProgress(); err != nil { + logTask.End(zap.ErrorLevel, err) + finalErr = err + return + } + // clean up task metas + if cleanup { + logTask.Info("cleanup task metas") + if cleanupErr := rc.taskMgr.Cleanup(context.Background()); cleanupErr != nil { + logTask.Warn("failed to clean task metas, you may need to restore them manually", zap.Error(cleanupErr)) + } + // cleanup table meta and schema db if needed. + if err := rc.taskMgr.CleanupAllMetas(context.Background()); err != nil { + logTask.Warn("failed to clean table task metas, you may need to restore them manually", zap.Error(err)) + } } }() - defer close(stopPeriodicActions) - taskCh := make(chan task, rc.cfg.App.IndexConcurrency) defer close(taskCh) @@ -1515,32 +1511,26 @@ func (rc *Controller) restoreTables(ctx context.Context) error { default: } - // stop periodic tasks for restore table such as pd schedulers and switch-mode tasks. - // this can help make cluster switching back to normal state more quickly. - // finishSchedulers() - // cancelFunc(switchBack) - // finishFuncCalled = true - taskFinished = true - - close(postProcessTaskChan) - // otherwise, we should run all tasks in the post-process task chan - for i := 0; i < rc.cfg.App.TableConcurrency; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for task := range postProcessTaskChan { - metaMgr := rc.metaMgrBuilder.TableMetaMgr(task.tr) - // force all the remain post-process tasks to be executed - _, err = task.tr.postProcess(ctx2, rc, task.cp, true, metaMgr) - restoreErr.Set(err) - } - }() + postProgress = func() error { + close(postProcessTaskChan) + // otherwise, we should run all tasks in the post-process task chan + for i := 0; i < rc.cfg.App.TableConcurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for task := range postProcessTaskChan { + metaMgr := rc.metaMgrBuilder.TableMetaMgr(task.tr) + // force all the remain post-process tasks to be executed + _, err = task.tr.postProcess(ctx2, rc, task.cp, true, metaMgr) + restoreErr.Set(err) + } + }() + } + wg.Wait() + return restoreErr.Get() } - wg.Wait() - err = restoreErr.Get() - logTask.End(zap.ErrorLevel, err) - return err + return nil } func (tr *TableRestore) restoreTable( diff --git a/br/pkg/lightning/restore/table_restore.go b/br/pkg/lightning/restore/table_restore.go index 2faf361bf3a3a..e340063d8a34b 100644 --- a/br/pkg/lightning/restore/table_restore.go +++ b/br/pkg/lightning/restore/table_restore.go @@ -812,7 +812,7 @@ func (tr *TableRestore) postProcess( } // Don't call FinishTable when other lightning will calculate checksum. 
- if err == nil && !hasDupe && needChecksum { + if err == nil && needChecksum { err = metaMgr.FinishTable(ctx) } diff --git a/br/tests/lightning_character_sets/run.sh b/br/tests/lightning_character_sets/run.sh index 6750ee08a4838..3d62ed26360ca 100755 --- a/br/tests/lightning_character_sets/run.sh +++ b/br/tests/lightning_character_sets/run.sh @@ -41,7 +41,7 @@ run_sql 'DROP TABLE charsets.gb18030;' run_lightning_expecting_fail --config "tests/$TEST_NAME/utf8mb4.toml" -d "tests/$TEST_NAME/gb18030" run_lightning --config "tests/$TEST_NAME/binary.toml" -d "tests/$TEST_NAME/gb18030" -run_sql 'SELECT sum(`Ö÷¼ü`) AS s FROM charsets.gb18030' +run_sql 'SELECT sum(`????`) AS s FROM charsets.gb18030' check_contains 's: 267' # utf8mb4 diff --git a/cmd/explaintest/r/index_merge.result b/cmd/explaintest/r/index_merge.result index f4d6571fd8fe3..ddce511db13c8 100644 --- a/cmd/explaintest/r/index_merge.result +++ b/cmd/explaintest/r/index_merge.result @@ -722,11 +722,11 @@ c1 c2 c3 c4 c5 explain select /*+ use_index_merge(t1) */ * from t1 where (c1 < 10 or c2 < 10) and substring(c3, 1, 1) = '1' order by 1; id estRows task access object operator info Sort_5 4433.77 root test.t1.c1 -└─Selection_12 4433.77 root eq(substring(cast(test.t1.c3, var_string(20)), 1, 1), "1") - └─IndexMerge_11 5542.21 root - ├─IndexRangeScan_8(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo - ├─IndexRangeScan_9(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo - └─TableRowIDScan_10(Probe) 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo +└─IndexMerge_12 4433.77 root + ├─IndexRangeScan_8(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo + ├─IndexRangeScan_9(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo + └─Selection_11(Probe) 4433.77 cop[tikv] eq(substring(cast(test.t1.c3, var_string(20)), 1, 1), "1") + └─TableRowIDScan_10 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo select /*+ use_index_merge(t1) */ * from t1 where (c1 < 10 or c2 < 10) and substring(c3, 1, 1) = '1' order by 1; c1 c2 c3 c4 c5 1 1 1 1 1 diff --git a/cmd/explaintest/r/tpch.result b/cmd/explaintest/r/tpch.result index c9edd22189514..32afcf2ab4fad 100644 --- a/cmd/explaintest/r/tpch.result +++ b/cmd/explaintest/r/tpch.result @@ -1294,14 +1294,13 @@ cntrycode order by cntrycode; id estRows task access object operator info -Sort 1.00 root Column#27 -└─Projection 1.00 root Column#27, Column#28, Column#29 - └─HashAgg 1.00 root group by:Column#33, funcs:count(1)->Column#28, funcs:sum(Column#31)->Column#29, funcs:firstrow(Column#32)->Column#27 - └─Projection 0.00 root tpch.customer.c_acctbal, substring(tpch.customer.c_phone, 1, 2)->Column#32, substring(tpch.customer.c_phone, 1, 2)->Column#33 +Sort 1.00 root Column#31 +└─Projection 1.00 root Column#31, Column#32, Column#33 + └─HashAgg 1.00 root group by:Column#37, funcs:count(1)->Column#32, funcs:sum(Column#35)->Column#33, funcs:firstrow(Column#36)->Column#31 + └─Projection 0.00 root tpch.customer.c_acctbal, substring(tpch.customer.c_phone, 1, 2)->Column#36, substring(tpch.customer.c_phone, 1, 2)->Column#37 └─HashJoin 0.00 root anti semi join, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] ├─TableReader(Build) 75000000.00 root data:TableFullScan │ └─TableFullScan 75000000.00 cop[tikv] table:orders keep order:false - └─Selection(Probe) 0.00 root in(substring(tpch.customer.c_phone, 1, 2), "20", "40", "22", 
"30", "39", "42", "21") - └─TableReader 0.00 root data:Selection - └─Selection 0.00 cop[tikv] gt(tpch.customer.c_acctbal, NULL) - └─TableFullScan 7500000.00 cop[tikv] table:customer keep order:false + └─TableReader(Probe) 0.00 root data:Selection + └─Selection 0.00 cop[tikv] gt(tpch.customer.c_acctbal, NULL), in(substring(tpch.customer.c_phone, 1, 2), "20", "40", "22", "30", "39", "42", "21") + └─TableFullScan 7500000.00 cop[tikv] table:customer keep order:false diff --git a/config/config.go b/config/config.go index 2319d953286c7..1727fae20698f 100644 --- a/config/config.go +++ b/config/config.go @@ -65,6 +65,14 @@ const ( DefTableColumnCountLimit = 1017 // DefMaxOfTableColumnCountLimit is maximum limitation of the number of columns in a table DefMaxOfTableColumnCountLimit = 4096 + // DefStatsLoadConcurrencyLimit is limit of the concurrency of stats-load + DefStatsLoadConcurrencyLimit = 1 + // DefMaxOfStatsLoadConcurrencyLimit is maximum limitation of the concurrency of stats-load + DefMaxOfStatsLoadConcurrencyLimit = 128 + // DefStatsLoadQueueSizeLimit is limit of the size of stats-load request queue + DefStatsLoadQueueSizeLimit = 1 + // DefMaxOfStatsLoadQueueSizeLimit is maximum limitation of the size of stats-load request queue + DefMaxOfStatsLoadQueueSizeLimit = 100000 ) // Valid config maps @@ -483,11 +491,13 @@ type Performance struct { CommitterConcurrency int `toml:"committer-concurrency" json:"committer-concurrency"` MaxTxnTTL uint64 `toml:"max-txn-ttl" json:"max-txn-ttl"` // Deprecated - MemProfileInterval string `toml:"-" json:"-"` - IndexUsageSyncLease string `toml:"index-usage-sync-lease" json:"index-usage-sync-lease"` - PlanReplayerGCLease string `toml:"plan-replayer-gc-lease" json:"plan-replayer-gc-lease"` - GOGC int `toml:"gogc" json:"gogc"` - EnforceMPP bool `toml:"enforce-mpp" json:"enforce-mpp"` + MemProfileInterval string `toml:"-" json:"-"` + IndexUsageSyncLease string `toml:"index-usage-sync-lease" json:"index-usage-sync-lease"` + PlanReplayerGCLease string `toml:"plan-replayer-gc-lease" json:"plan-replayer-gc-lease"` + GOGC int `toml:"gogc" json:"gogc"` + EnforceMPP bool `toml:"enforce-mpp" json:"enforce-mpp"` + StatsLoadConcurrency uint `toml:"stats-load-concurrency" json:"stats-load-concurrency"` + StatsLoadQueueSize uint `toml:"stats-load-queue-size" json:"stats-load-queue-size"` } // PlanCache is the PlanCache section of the config. @@ -702,10 +712,12 @@ var defaultConf = Config{ CommitterConcurrency: defTiKVCfg.CommitterConcurrency, MaxTxnTTL: defTiKVCfg.MaxTxnTTL, // 1hour // TODO: set indexUsageSyncLease to 60s. 
- IndexUsageSyncLease: "0s", - GOGC: 100, - EnforceMPP: false, - PlanReplayerGCLease: "10m", + IndexUsageSyncLease: "0s", + GOGC: 100, + EnforceMPP: false, + PlanReplayerGCLease: "10m", + StatsLoadConcurrency: 5, + StatsLoadQueueSize: 1000, }, ProxyProtocol: ProxyProtocol{ Networks: "", @@ -1001,6 +1013,14 @@ func (c *Config) Valid() error { c.Security.SpilledFileEncryptionMethod, SpilledFileEncryptionMethodPlaintext, SpilledFileEncryptionMethodAES128CTR) } + // check stats load config + if c.Performance.StatsLoadConcurrency < DefStatsLoadConcurrencyLimit || c.Performance.StatsLoadConcurrency > DefMaxOfStatsLoadConcurrencyLimit { + return fmt.Errorf("stats-load-concurrency should be [%d, %d]", DefStatsLoadConcurrencyLimit, DefMaxOfStatsLoadConcurrencyLimit) + } + if c.Performance.StatsLoadQueueSize < DefStatsLoadQueueSizeLimit || c.Performance.StatsLoadQueueSize > DefMaxOfStatsLoadQueueSizeLimit { + return fmt.Errorf("stats-load-queue-size should be [%d, %d]", DefStatsLoadQueueSizeLimit, DefMaxOfStatsLoadQueueSizeLimit) + } + // test log level l := zap.NewAtomicLevel() return l.UnmarshalText([]byte(c.Log.Level)) diff --git a/config/config_test.go b/config/config_test.go index 25c79dc40ebe3..6515d76c56453 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -674,3 +674,24 @@ func TestConfigExample(t *testing.T) { } } } + +func TestStatsLoadLimit(t *testing.T) { + conf := NewConfig() + checkConcurrencyValid := func(concurrency int, shouldBeValid bool) { + conf.Performance.StatsLoadConcurrency = uint(concurrency) + require.Equal(t, shouldBeValid, conf.Valid() == nil) + } + checkConcurrencyValid(DefStatsLoadConcurrencyLimit, true) + checkConcurrencyValid(DefStatsLoadConcurrencyLimit-1, false) + checkConcurrencyValid(DefMaxOfStatsLoadConcurrencyLimit, true) + checkConcurrencyValid(DefMaxOfStatsLoadConcurrencyLimit+1, false) + conf = NewConfig() + checkQueueSizeValid := func(queueSize int, shouldBeValid bool) { + conf.Performance.StatsLoadQueueSize = uint(queueSize) + require.Equal(t, shouldBeValid, conf.Valid() == nil) + } + checkQueueSizeValid(DefStatsLoadQueueSizeLimit, true) + checkQueueSizeValid(DefStatsLoadQueueSizeLimit-1, false) + checkQueueSizeValid(DefMaxOfStatsLoadQueueSizeLimit, true) + checkQueueSizeValid(DefMaxOfStatsLoadQueueSizeLimit+1, false) +} diff --git a/ddl/column.go b/ddl/column.go index 0fc277e6ac78b..a93183d97f82a 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -1761,17 +1761,20 @@ func isColumnWithIndex(colName string, indices []*model.IndexInfo) bool { return false } -func isColumnCanDropWithIndex(isMultiSchemaChange bool, colName string, indices []*model.IndexInfo) bool { +func isColumnCanDropWithIndex(isMultiSchemaChange bool, colName string, indices []*model.IndexInfo) error { for _, indexInfo := range indices { - if indexInfo.Primary || len(indexInfo.Columns) > 1 || (!isMultiSchemaChange && len(indexInfo.Columns) == 1) { + if indexInfo.Primary || len(indexInfo.Columns) > 1 { for _, col := range indexInfo.Columns { if col.Name.L == colName { - return false + return errCantDropColWithIndex.GenWithStack("can't drop column %s with composite index covered or Primary Key covered now", colName) } } } + if len(indexInfo.Columns) == 1 && indexInfo.Columns[0].Name.L == colName && !isMultiSchemaChange { + return errCantDropColWithIndex.GenWithStack("can't drop column %s with tidb_enable_change_multi_schema is disable", colName) + } } - return true + return nil } func listIndicesWithColumn(colName string, indices []*model.IndexInfo) []*model.IndexInfo { diff 
--git a/ddl/ddl_api.go b/ddl/ddl_api.go index 1406cbc3bc9d1..1d5251d7e7e37 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -5795,8 +5795,9 @@ func isDroppableColumn(multiSchemaChange bool, tblInfo *model.TableInfo, colName colName, tblInfo.Name) } // We only support dropping column with single-value none Primary Key index covered now. - if !isColumnCanDropWithIndex(multiSchemaChange, colName.L, tblInfo.Indices) { - return errCantDropColWithIndex.GenWithStack("can't drop column %s with composite index covered or Primary Key covered now", colName) + err := isColumnCanDropWithIndex(multiSchemaChange, colName.L, tblInfo.Indices) + if err != nil { + return err } // Check the column with foreign key. if fkInfo := getColumnForeignKeyInfo(colName.L, tblInfo.ForeignKeys); fkInfo != nil { diff --git a/domain/domain.go b/domain/domain.go index d5de5bfab2182..58f46588ac0aa 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -1296,6 +1296,16 @@ func (do *Domain) UpdateTableStatsLoop(ctx sessionctx.Context) error { return nil } +// StartLoadStatsSubWorkers starts sub workers with new sessions to load stats concurrently +func (do *Domain) StartLoadStatsSubWorkers(ctxList []sessionctx.Context) { + statsHandle := do.StatsHandle() + for i, ctx := range ctxList { + statsHandle.StatsLoad.SubCtxs[i] = ctx + do.wg.Add(1) + go statsHandle.SubLoadWorker(ctx, do.exit, &do.wg) + } +} + func (do *Domain) newOwnerManager(prompt, ownerKey string) owner.Manager { id := do.ddl.OwnerManager().ID() var statsOwner owner.Manager diff --git a/domain/sysvar_cache.go b/domain/sysvar_cache.go index c4f28629fa332..d99de612dfcda 100644 --- a/domain/sysvar_cache.go +++ b/domain/sysvar_cache.go @@ -245,6 +245,15 @@ func (do *Domain) checkEnableServerGlobalVar(name, sVal string) { variable.PersistAnalyzeOptions.Store(variable.TiDBOptOn(sVal)) case variable.TiDBEnableColumnTracking: variable.EnableColumnTracking.Store(variable.TiDBOptOn(sVal)) + case variable.TiDBStatsLoadSyncWait: + var val int64 + val, err = strconv.ParseInt(sVal, 10, 64) + if err != nil { + break + } + variable.StatsLoadSyncWait.Store(val) + case variable.TiDBStatsLoadPseudoTimeout: + variable.StatsLoadPseudoTimeout.Store(variable.TiDBOptOn(sVal)) } if err != nil { logutil.BgLogger().Error(fmt.Sprintf("load global variable %s error", name), zap.Error(err)) diff --git a/executor/executor_test.go b/executor/executor_test.go index 454a503ba91c0..f75e7fa97ff6d 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -9452,6 +9452,19 @@ func (s *testSuiteWithData) TestPlanReplayerDumpSingle(c *C) { } } +func (s *testSuiteWithData) TestDropColWithPrimaryKey(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id int primary key, c1 int, c2 int, c3 int, index idx1(c1, c2), index idx2(c3))") + tk.MustExec("set global tidb_enable_change_multi_schema = off") + tk.MustGetErrMsg("alter table t drop column id", "[ddl:8200]Unsupported drop integer primary key") + tk.MustGetErrMsg("alter table t drop column c1", "[ddl:8200]can't drop column c1 with composite index covered or Primary Key covered now") + tk.MustGetErrMsg("alter table t drop column c3", "[ddl:8200]can't drop column c3 with tidb_enable_change_multi_schema is disable") + tk.MustExec("set global tidb_enable_change_multi_schema = on") + tk.MustExec("alter table t drop column c3") +} + func (s *testSuiteP1) TestIssue28935(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("set 
@@tidb_enable_vectorized_expression=true") diff --git a/expression/expr_to_pb_test.go b/expression/expr_to_pb_test.go index 840c7fd4cf7ef..e02e1676d890b 100644 --- a/expression/expr_to_pb_test.go +++ b/expression/expr_to_pb_test.go @@ -525,6 +525,16 @@ func TestExprPushDownToFlash(t *testing.T) { require.NoError(t, err) exprs = append(exprs, function) + // lpad + function, err = NewFunction(mock.NewContext(), ast.Lpad, types.NewFieldType(mysql.TypeString), stringColumn, int32Column, stringColumn) + require.NoError(t, err) + exprs = append(exprs, function) + + // rpad + function, err = NewFunction(mock.NewContext(), ast.Rpad, types.NewFieldType(mysql.TypeString), stringColumn, int32Column, stringColumn) + require.NoError(t, err) + exprs = append(exprs, function) + function, err = NewFunction(mock.NewContext(), ast.If, types.NewFieldType(mysql.TypeLonglong), intColumn, intColumn, intColumn) require.NoError(t, err) exprs = append(exprs, function) @@ -1068,7 +1078,7 @@ func TestExprPushDownToTiKV(t *testing.T) { exprs := make([]Expression, 0) //jsonColumn := genColumn(mysql.TypeJSON, 1) - //intColumn := genColumn(mysql.TypeLonglong, 2) + intColumn := genColumn(mysql.TypeLonglong, 2) //realColumn := genColumn(mysql.TypeDouble, 3) //decimalColumn := genColumn(mysql.TypeNewDecimal, 4) stringColumn := genColumn(mysql.TypeString, 5) @@ -1076,6 +1086,7 @@ func TestExprPushDownToTiKV(t *testing.T) { binaryStringColumn := genColumn(mysql.TypeString, 7) binaryStringColumn.RetType.Collate = charset.CollationBin + // Test exprs that cannot be pushed. function, err := NewFunction(mock.NewContext(), ast.InetAton, types.NewFieldType(mysql.TypeString), stringColumn) require.NoError(t, err) exprs = append(exprs, function) @@ -1111,6 +1122,26 @@ func TestExprPushDownToTiKV(t *testing.T) { pushed, remained := PushDownExprs(sc, exprs, client, kv.TiKV) require.Len(t, pushed, 0) require.Len(t, remained, len(exprs)) + + // Test exprs that can be pushed. + exprs = exprs[:0] + pushed = pushed[:0] + remained = remained[:0] + + substringRelated := []string{ast.Substr, ast.Substring, ast.Mid} + for _, exprName := range substringRelated { + function, err = NewFunction(mock.NewContext(), exprName, types.NewFieldType(mysql.TypeString), stringColumn, intColumn, intColumn) + require.NoError(t, err) + exprs = append(exprs, function) + } + + function, err = NewFunction(mock.NewContext(), ast.CharLength, types.NewFieldType(mysql.TypeString), stringColumn) + require.NoError(t, err) + exprs = append(exprs, function) + + pushed, remained = PushDownExprs(sc, exprs, client, kv.TiKV) + require.Len(t, pushed, len(exprs)) + require.Len(t, remained, 0) } func TestExprOnlyPushDownToTiKV(t *testing.T) { diff --git a/expression/expression.go b/expression/expression.go index eea4f0a604bb5..82c2ea4511d94 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -967,7 +967,7 @@ func scalarExprSupportedByTiKV(sf *ScalarFunction) bool { // string functions. ast.Length, ast.BitLength, ast.Concat, ast.ConcatWS /*ast.Locate,*/, ast.Replace, ast.ASCII, ast.Hex, ast.Reverse, ast.LTrim, ast.RTrim /*ast.Left,*/, ast.Strcmp, ast.Space, ast.Elt, ast.Field, - InternalFuncFromBinary, InternalFuncToBinary, + InternalFuncFromBinary, InternalFuncToBinary, ast.Mid, ast.Substring, ast.Substr, ast.CharLength, // json functions. 
ast.JSONType, ast.JSONExtract, ast.JSONObject, ast.JSONArray, ast.JSONMerge, ast.JSONSet, @@ -1048,6 +1048,7 @@ func scalarExprSupportedByFlash(function *ScalarFunction) bool { ast.InetNtoa, ast.InetAton, ast.Inet6Ntoa, ast.Inet6Aton, ast.Coalesce, ast.ASCII, ast.Length, ast.Trim, ast.Position, ast.Format, ast.LTrim, ast.RTrim, + ast.Lpad, ast.Rpad, ast.Hour, ast.Minute, ast.Second, ast.MicroSecond: switch function.Function.PbCode() { case tipb.ScalarFuncSig_InDuration, diff --git a/expression/integration_test.go b/expression/integration_test.go index 9477c5524d0c6..48284410b94ca 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -3472,36 +3472,36 @@ func TestExprPushdown(t *testing.T) { "(4,'511111','611',7,8,9),(5,'611111','711',8,9,10)") // case 1, index scan without double read, some filters can not be pushed to cop task - rows := tk.MustQuery("explain format = 'brief' select col2, col1 from t use index(key1) where col2 like '5%' and substr(col1, 1, 1) = '4'").Rows() + rows := tk.MustQuery("explain format = 'brief' select col2, col1 from t use index(key1) where col2 like '5%' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Rows() require.Equal(t, "root", fmt.Sprintf("%v", rows[1][2])) - require.Equal(t, "eq(substr(test.t.col1, 1, 1), \"4\")", fmt.Sprintf("%v", rows[1][4])) + require.Equal(t, "eq(from_base64(to_base64(substr(test.t.col1, 1, 1))), \"4\")", fmt.Sprintf("%v", rows[1][4])) require.Equal(t, "cop[tikv]", fmt.Sprintf("%v", rows[3][2])) require.Equal(t, "like(test.t.col2, \"5%\", 92)", fmt.Sprintf("%v", rows[3][4])) - tk.MustQuery("select col2, col1 from t use index(key1) where col2 like '5%' and substr(col1, 1, 1) = '4'").Check(testkit.Rows("511 411111")) - tk.MustQuery("select count(col2) from t use index(key1) where col2 like '5%' and substr(col1, 1, 1) = '4'").Check(testkit.Rows("1")) + tk.MustQuery("select col2, col1 from t use index(key1) where col2 like '5%' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Check(testkit.Rows("511 411111")) + tk.MustQuery("select count(col2) from t use index(key1) where col2 like '5%' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Check(testkit.Rows("1")) // case 2, index scan without double read, none of the filters can be pushed to cop task - rows = tk.MustQuery("explain format = 'brief' select col1, col2 from t use index(key2) where substr(col2, 1, 1) = '5' and substr(col1, 1, 1) = '4'").Rows() + rows = tk.MustQuery("explain format = 'brief' select col1, col2 from t use index(key2) where from_base64(to_base64(substr(col2, 1, 1))) = '5' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Rows() require.Equal(t, "root", fmt.Sprintf("%v", rows[0][2])) - require.Equal(t, "eq(substr(test.t.col1, 1, 1), \"4\"), eq(substr(test.t.col2, 1, 1), \"5\")", fmt.Sprintf("%v", rows[0][4])) - tk.MustQuery("select col1, col2 from t use index(key2) where substr(col2, 1, 1) = '5' and substr(col1, 1, 1) = '4'").Check(testkit.Rows("411111 511")) - tk.MustQuery("select count(col1) from t use index(key2) where substr(col2, 1, 1) = '5' and substr(col1, 1, 1) = '4'").Check(testkit.Rows("1")) + require.Equal(t, "eq(from_base64(to_base64(substr(test.t.col1, 1, 1))), \"4\"), eq(from_base64(to_base64(substr(test.t.col2, 1, 1))), \"5\")", fmt.Sprintf("%v", rows[0][4])) + tk.MustQuery("select col1, col2 from t use index(key2) where from_base64(to_base64(substr(col2, 1, 1))) = '5' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Check(testkit.Rows("411111 511")) + tk.MustQuery("select count(col1) from t 
use index(key2) where from_base64(to_base64(substr(col2, 1, 1))) = '5' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Check(testkit.Rows("1")) // case 3, index scan with double read, some filters can not be pushed to cop task - rows = tk.MustQuery("explain format = 'brief' select id from t use index(key1) where col2 like '5%' and substr(col1, 1, 1) = '4'").Rows() + rows = tk.MustQuery("explain format = 'brief' select id from t use index(key1) where col2 like '5%' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Rows() require.Equal(t, "root", fmt.Sprintf("%v", rows[1][2])) - require.Equal(t, "eq(substr(test.t.col1, 1, 1), \"4\")", fmt.Sprintf("%v", rows[1][4])) + require.Equal(t, "eq(from_base64(to_base64(substr(test.t.col1, 1, 1))), \"4\")", fmt.Sprintf("%v", rows[1][4])) require.Equal(t, "cop[tikv]", fmt.Sprintf("%v", rows[3][2])) require.Equal(t, "like(test.t.col2, \"5%\", 92)", fmt.Sprintf("%v", rows[3][4])) - tk.MustQuery("select id from t use index(key1) where col2 like '5%' and substr(col1, 1, 1) = '4'").Check(testkit.Rows("3")) - tk.MustQuery("select count(id) from t use index(key1) where col2 like '5%' and substr(col1, 1, 1) = '4'").Check(testkit.Rows("1")) + tk.MustQuery("select id from t use index(key1) where col2 like '5%' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Check(testkit.Rows("3")) + tk.MustQuery("select count(id) from t use index(key1) where col2 like '5%' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Check(testkit.Rows("1")) // case 4, index scan with double read, none of the filters can be pushed to cop task - rows = tk.MustQuery("explain format = 'brief' select id from t use index(key2) where substr(col2, 1, 1) = '5' and substr(col1, 1, 1) = '4'").Rows() + rows = tk.MustQuery("explain format = 'brief' select id from t use index(key2) where from_base64(to_base64(substr(col2, 1, 1))) = '5' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Rows() require.Equal(t, "root", fmt.Sprintf("%v", rows[1][2])) - require.Equal(t, "eq(substr(test.t.col1, 1, 1), \"4\"), eq(substr(test.t.col2, 1, 1), \"5\")", fmt.Sprintf("%v", rows[1][4])) - tk.MustQuery("select id from t use index(key2) where substr(col2, 1, 1) = '5' and substr(col1, 1, 1) = '4'").Check(testkit.Rows("3")) - tk.MustQuery("select count(id) from t use index(key2) where substr(col2, 1, 1) = '5' and substr(col1, 1, 1) = '4'").Check(testkit.Rows("1")) + require.Equal(t, "eq(from_base64(to_base64(substr(test.t.col1, 1, 1))), \"4\"), eq(from_base64(to_base64(substr(test.t.col2, 1, 1))), \"5\")", fmt.Sprintf("%v", rows[1][4])) + tk.MustQuery("select id from t use index(key2) where from_base64(to_base64(substr(col2, 1, 1))) = '5' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Check(testkit.Rows("3")) + tk.MustQuery("select count(id) from t use index(key2) where from_base64(to_base64(substr(col2, 1, 1))) = '5' and from_base64(to_base64(substr(col1, 1, 1))) = '4'").Check(testkit.Rows("1")) } func TestIssue16973(t *testing.T) { store, clean := testkit.CreateMockStore(t) diff --git a/go.mod b/go.mod index bb7699ec0498c..4279f34143727 100644 --- a/go.mod +++ b/go.mod @@ -84,6 +84,7 @@ require ( golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e golang.org/x/text v0.3.7 + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect golang.org/x/tools v0.1.8 google.golang.org/api v0.54.0 google.golang.org/grpc v1.40.0 diff --git a/metrics/grafana/tidb.json b/metrics/grafana/tidb.json index a828967a8ada9..b8d64030406fc 100644 
--- a/metrics/grafana/tidb.json +++ b/metrics/grafana/tidb.json @@ -10859,6 +10859,226 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "tidb-cluster", + "description": "", + "fieldConfig": { + "defaults": {}, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 0, + "y": 33 + }, + "hiddenSeries": false, + "id": 229, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.5.11", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "exemplar": true, + "expr": "sum(rate(tidb_statistics_sync_load_total{tidb_cluster=\"$tidb_cluster\"}[1m])) by (type)", + "format": "time_series", + "interval": "", + "intervalFactor": 2, + "legendFormat": "sync-load", + "refId": "A", + "step": 30 + }, + { + "exemplar": true, + "expr": "sum(rate(tidb_statistics_sync_load_timeout_total{tidb_cluster=\"$tidb_cluster\"}[1m])) by (type)", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 2, + "legendFormat": "timeout", + "refId": "B", + "step": 30 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Sync Load QPS", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "tidb-cluster", + "description": "", + "fieldConfig": { + "defaults": {}, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 8, + "y": 33 + }, + "hiddenSeries": false, + "id": 230, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.5.11", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "exemplar": true, + "expr": "histogram_quantile(0.95, sum(rate(tidb_statistics_sync_load_latency_millis_bucket{tidb_cluster=\"$tidb_cluster\"}[1m])) by (le))", + "format": "time_series", + "interval": "", + "intervalFactor": 2, + "legendFormat": "sync-load", + "refId": "A", + "step": 30 + }, + { + "exemplar": true, + "expr": "histogram_quantile(0.95, sum(rate(tidb_statistics_read_stats_latency_millis_bucket{tidb_cluster=\"$tidb_cluster\"}[1m])) by (le))", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 2, + "legendFormat": "read-stats", + "refId": "B", + "step": 30 + } + ], + "thresholds": [], + "timeFrom": null, + 
"timeRegions": [], + "timeShift": null, + "title": "Sync Load Latency 95 (ms)", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "repeat": null, diff --git a/metrics/metrics.go b/metrics/metrics.go index 0151987951edc..a42fc1cf8b954 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -96,6 +96,10 @@ func RegisterMetrics() { prometheus.MustRegister(HandleJobHistogram) prometheus.MustRegister(SignificantFeedbackCounter) prometheus.MustRegister(FastAnalyzeHistogram) + prometheus.MustRegister(SyncLoadCounter) + prometheus.MustRegister(SyncLoadTimeoutCounter) + prometheus.MustRegister(SyncLoadHistogram) + prometheus.MustRegister(ReadStatsHistogram) prometheus.MustRegister(JobsGauge) prometheus.MustRegister(KeepAliveCounter) prometheus.MustRegister(LoadPrivilegeCounter) diff --git a/metrics/stats.go b/metrics/stats.go index a3347dd597716..c4b74cf088915 100644 --- a/metrics/stats.go +++ b/metrics/stats.go @@ -94,4 +94,38 @@ var ( Help: "Bucketed histogram of some stats in fast analyze.", Buckets: prometheus.ExponentialBuckets(1, 2, 16), }, []string{LblSQLType, LblType}) + + SyncLoadCounter = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "statistics", + Name: "sync_load_total", + Help: "Counter of sync load.", + }) + + SyncLoadTimeoutCounter = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "statistics", + Name: "sync_load_timeout_total", + Help: "Counter of sync load timeout.", + }) + + SyncLoadHistogram = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "statistics", + Name: "sync_load_latency_millis", + Help: "Bucketed histogram of latency time (ms) of sync load.", + Buckets: prometheus.ExponentialBuckets(1, 2, 22), // 1ms ~ 1h + }) + + ReadStatsHistogram = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "statistics", + Name: "read_stats_latency_millis", + Help: "Bucketed histogram of latency time (ms) of stats read during sync-load.", + Buckets: prometheus.ExponentialBuckets(1, 2, 22), // 1ms ~ 1h + }) ) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 95adfa8da44fa..847af9161c092 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -5004,72 +5004,40 @@ func (s *testIntegrationSuite) TestIssue30200(c *C) { tk.MustExec("create table t1(c1 varchar(100), c2 varchar(100), key(c1), key(c2), c3 varchar(100));") tk.MustExec("insert into t1 values('ab', '10', '10');") - // lpad has not been pushed to TiKV or TiFlash. 
- tk.MustQuery("explain format=brief select /*+ use_index_merge(t1) */ * from t1 where c1 = 'ab' or c2 = '10' and char_length(lpad(c1, 10, 'a')) = 10;").Check(testkit.Rows( - "Selection 15.99 root or(eq(test.t1.c1, \"ab\"), and(eq(test.t1.c2, \"10\"), eq(char_length(lpad(test.t1.c1, 10, \"a\")), 10)))", - "└─IndexMerge 19.99 root ", - " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t1, index:c1(c1) range:[\"ab\",\"ab\"], keep order:false, stats:pseudo", - " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t1, index:c2(c2) range:[\"10\",\"10\"], keep order:false, stats:pseudo", - " └─TableRowIDScan(Probe) 19.99 cop[tikv] table:t1 keep order:false, stats:pseudo")) - tk.MustQuery("select /*+ use_index_merge(t1) */ 1 from t1 where c1 = 'de' or c2 = '10' and char_length(lpad(c1, 10, 'a')) = 10;").Check(testkit.Rows("1")) - - // `left` has not been pushed to TiKV, but it has been pushed to TiFlash. - tk.MustQuery("explain format=brief select /*+ use_index_merge(t1) */ * from t1 where c1 = 'ab' or c2 = '10' and char_length(left(c1, 10)) = 10;").Check(testkit.Rows( - "Selection 0.04 root or(eq(test.t1.c1, \"ab\"), and(eq(test.t1.c2, \"10\"), eq(char_length(left(test.t1.c1, 10)), 10)))", - "└─IndexMerge 19.99 root ", - " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t1, index:c1(c1) range:[\"ab\",\"ab\"], keep order:false, stats:pseudo", - " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t1, index:c2(c2) range:[\"10\",\"10\"], keep order:false, stats:pseudo", - " └─TableRowIDScan(Probe) 19.99 cop[tikv] table:t1 keep order:false, stats:pseudo")) - tk.MustQuery("select /*+ use_index_merge(t1) */ 1 from t1 where c1 = 'ab' or c2 = '10' and char_length(left(c1, 10)) = 10;").Check(testkit.Rows("1")) - - // If no hint, we cannot use index merge if filter cannot be pushed to any storage. + tk.MustExec("drop table if exists tt1;") + tk.MustExec("create table tt1(c1 varchar(100), c2 varchar(100), c3 varchar(100), c4 varchar(100), key idx_0(c1), key idx_1(c2, c3));") + tk.MustExec("insert into tt1 values('ab', '10', '10', '10');") + + tk.MustExec("drop table if exists tt2;") + tk.MustExec("create table tt2 (c1 int , pk int, primary key( pk ) , unique key( c1));") + tk.MustExec("insert into tt2 values(-3896405, -1), (-2, 1), (-1, -2);") + + tk.MustExec("drop table if exists tt3;") + tk.MustExec("create table tt3(c1 int, c2 int, c3 int as (c1 + c2), key(c1), key(c2), key(c3));") + tk.MustExec("insert into tt3(c1, c2) values(1, 1);") + oriIndexMergeSwitcher := tk.MustQuery("select @@tidb_enable_index_merge;").Rows()[0][0].(string) tk.MustExec("set tidb_enable_index_merge = on;") defer func() { tk.MustExec(fmt.Sprintf("set tidb_enable_index_merge = %s;", oriIndexMergeSwitcher)) }() - tk.MustQuery("explain format=brief select * from t1 where c1 = 'ab' or c2 = '10' and char_length(lpad(c1, 10, 'a')) = 10;").Check(testkit.Rows( - "Selection 8000.00 root or(eq(test.t1.c1, \"ab\"), and(eq(test.t1.c2, \"10\"), eq(char_length(lpad(test.t1.c1, 10, \"a\")), 10)))", - "└─TableReader 10000.00 root data:TableFullScan", - " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo")) - tk.MustExec("use test") - tk.MustExec("drop table if exists t1;") - tk.MustExec("create table t1(c1 varchar(100), c2 varchar(100), c3 varchar(100), c4 varchar(100), key idx_0(c1), key idx_1(c2, c3));") - tk.MustExec("insert into t1 values('ab', '10', '10', '10');") - // c3 is part of idx_1, so it will be put in partial_path's IndexFilters instead of TableFilters. - // But it still cannot be pushed to TiKV. 
- tk.MustQuery("explain select /*+ use_index_merge(t1) */ 1 from t1 where c1 = 'de' or c2 = '10' and char_length(lpad(c3, 10, 'a')) = 10;").Check(testkit.Rows( - "Projection_4 15.99 root 1->Column#6", - "└─Selection_5 15.99 root or(eq(test.t1.c1, \"de\"), and(eq(test.t1.c2, \"10\"), eq(char_length(lpad(test.t1.c3, 10, \"a\")), 10)))", - " └─IndexMerge_9 19.99 root ", - " ├─IndexRangeScan_6(Build) 10.00 cop[tikv] table:t1, index:idx_0(c1) range:[\"de\",\"de\"], keep order:false, stats:pseudo", - " ├─IndexRangeScan_7(Build) 10.00 cop[tikv] table:t1, index:idx_1(c2, c3) range:[\"10\",\"10\"], keep order:false, stats:pseudo", - " └─TableRowIDScan_8(Probe) 19.99 cop[tikv] table:t1 keep order:false, stats:pseudo")) - tk.MustQuery("select /*+ use_index_merge(t1) */ 1 from t1 where c1 = 'de' or c2 = '10' and char_length(lpad(c3, 10, 'a')) = 10;").Check(testkit.Rows("1")) - - tk.MustExec("drop table if exists t1;") - tk.MustExec("create table t1 (c1 int , pk int, primary key( pk ) , unique key( c1));") - tk.MustExec("insert into t1 values(-3896405, -1), (-2, 1), (-1, -2);") - // to_base64(left(pk, 5)) is in partial_path's TableFilters. But it cannot be pushed to TiKV. So it should be executed in TiDB. - tk.MustQuery("explain select /*+ use_index_merge( t1 ) */ * from t1 where t1.c1 in (-3896405) or t1.pk in (1, 53330) and to_base64(left(pk, 5));").Check(testkit.Rows( - "Selection_5 2.40 root or(eq(test.t1.c1, -3896405), and(in(test.t1.pk, 1, 53330), istrue_with_null(cast(to_base64(left(cast(test.t1.pk, var_string(20)), 5)), double BINARY))))", - "└─IndexMerge_9 3.00 root ", - " ├─IndexRangeScan_6(Build) 1.00 cop[tikv] table:t1, index:c1(c1) range:[-3896405,-3896405], keep order:false, stats:pseudo", - " ├─TableRangeScan_7(Build) 2.00 cop[tikv] table:t1 range:[1,1], [53330,53330], keep order:false, stats:pseudo", - " └─TableRowIDScan_8(Probe) 3.00 cop[tikv] table:t1 keep order:false, stats:pseudo")) - tk.MustQuery("select /*+ use_index_merge( t1 ) */ * from t1 where t1.c1 in (-3896405) or t1.pk in (1, 53330) and to_base64(left(pk, 5));").Check(testkit.Rows("-3896405 -1")) - - tk.MustExec("drop table if exists t1;") - tk.MustExec("create table t1(c1 int, c2 int, c3 int as (c1 + c2), key(c1), key(c2), key(c3));") - tk.MustExec("insert into t1(c1, c2) values(1, 1);") - tk.MustQuery("explain format=brief select /*+ use_index_merge(t1) */ * from t1 where c1 < -10 or c2 < 10 and reverse(c3) = '2';").Check(testkit.Rows( - "Selection 2825.66 root or(lt(test.t1.c1, -10), and(lt(test.t1.c2, 10), eq(reverse(cast(test.t1.c3, var_string(20))), \"2\")))", - "└─IndexMerge 5542.21 root ", - " ├─IndexRangeScan(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,-10), keep order:false, stats:pseudo", - " ├─IndexRangeScan(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo", - " └─TableRowIDScan(Probe) 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo")) - tk.MustQuery("select /*+ use_index_merge(t1) */ * from t1 where c1 < -10 or c2 < 10 and reverse(c3) = '2';").Check(testkit.Rows("1 1 2")) + var input []string + var output []struct { + SQL string + Plan []string + Res []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery("explain format=brief " + tt).Rows()) + output[i].Res = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + tk.MustQuery("explain format=brief " + 
tt).Check(testkit.Rows(output[i].Plan...)) + tk.MustQuery(tt).Check(testkit.Rows(output[i].Res...)) + } } func (s *testIntegrationSuite) TestIssue29705(c *C) { diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 0f526ece29268..5cdd693a7e77e 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -62,8 +62,10 @@ const ( flagPredicatePushDown flagEliminateOuterJoin flagPartitionProcessor + flagCollectPredicateColumnsPoint flagPushDownAgg flagPushDownTopN + flagSyncWaitStatsLoadPoint flagJoinReOrder flagPrunColumnsAgain ) @@ -80,8 +82,10 @@ var optRuleList = []logicalOptRule{ &ppdSolver{}, &outerJoinEliminator{}, &partitionProcessor{}, + &collectPredicateColumnsPoint{}, &aggregationPushDownSolver{}, &pushDownTopNOptimizer{}, + &syncWaitStatsLoadPoint{}, &joinReOrderSolver{}, &columnPruner{}, // column pruning again at last, note it will mess up the results of buildKeySolver } @@ -257,11 +261,6 @@ func checkStableResultMode(sctx sessionctx.Context) bool { // DoOptimize optimizes a logical plan to a physical plan. func DoOptimize(ctx context.Context, sctx sessionctx.Context, flag uint64, logic LogicalPlan) (PhysicalPlan, float64, error) { - // TODO: move it to the logic of sync load hist-needed columns. - if variable.EnableColumnTracking.Load() { - predicateColumns, _ := CollectColumnStatsUsage(logic, true, false) - sctx.UpdateColStatsUsage(predicateColumns) - } // if there is something after flagPrunColumns, do flagPrunColumnsAgain if flag&flagPrunColumns > 0 && flag-flagPrunColumns > flagPrunColumns { flag |= flagPrunColumnsAgain @@ -269,6 +268,8 @@ func DoOptimize(ctx context.Context, sctx sessionctx.Context, flag uint64, logic if checkStableResultMode(sctx) { flag |= flagStabilizeResults } + flag |= flagCollectPredicateColumnsPoint + flag |= flagSyncWaitStatsLoadPoint logic, err := logicalOptimize(ctx, flag, logic) if err != nil { return nil, 0, err diff --git a/planner/core/plan_stats.go b/planner/core/plan_stats.go new file mode 100644 index 0000000000000..f1100061b9f3b --- /dev/null +++ b/planner/core/plan_stats.go @@ -0,0 +1,112 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package core + +import ( + "context" + "time" + + "github.com/cznic/mathutil" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" +) + +type collectPredicateColumnsPoint struct{} + +func (c collectPredicateColumnsPoint) optimize(ctx context.Context, plan LogicalPlan, op *logicalOptimizeOp) (LogicalPlan, error) { + if plan.SCtx().GetSessionVars().InRestrictedSQL { + return plan, nil + } + predicateNeeded := variable.EnableColumnTracking.Load() + syncWait := plan.SCtx().GetSessionVars().StatsLoadSyncWait * time.Millisecond.Nanoseconds() + histNeeded := syncWait > 0 + predicateColumns, histNeededColumns := CollectColumnStatsUsage(plan, predicateNeeded, histNeeded) + if len(predicateColumns) > 0 { + plan.SCtx().UpdateColStatsUsage(predicateColumns) + } + if histNeeded && len(histNeededColumns) > 0 { + err := RequestLoadColumnStats(plan.SCtx(), histNeededColumns, syncWait) + return plan, err + } + return plan, nil +} + +func (c collectPredicateColumnsPoint) name() string { + return "collect_predicate_columns_point" +} + +type syncWaitStatsLoadPoint struct{} + +func (s syncWaitStatsLoadPoint) optimize(ctx context.Context, plan LogicalPlan, op *logicalOptimizeOp) (LogicalPlan, error) { + if plan.SCtx().GetSessionVars().InRestrictedSQL { + return plan, nil + } + _, err := SyncWaitStatsLoad(plan) + return plan, err +} + +func (s syncWaitStatsLoadPoint) name() string { + return "sync_wait_stats_load_point" +} + +const maxDuration = 1<<63 - 1 + +// RequestLoadColumnStats send requests to stats handle +func RequestLoadColumnStats(ctx sessionctx.Context, neededColumns []model.TableColumnID, syncWait int64) error { + stmtCtx := ctx.GetSessionVars().StmtCtx + hintMaxExecutionTime := int64(stmtCtx.MaxExecutionTime) + if hintMaxExecutionTime <= 0 { + hintMaxExecutionTime = maxDuration + } + sessMaxExecutionTime := int64(ctx.GetSessionVars().MaxExecutionTime) + if sessMaxExecutionTime <= 0 { + sessMaxExecutionTime = maxDuration + } + waitTime := mathutil.MinInt64(syncWait, mathutil.MinInt64(hintMaxExecutionTime, sessMaxExecutionTime)) + var timeout = time.Duration(waitTime) + err := domain.GetDomain(ctx).StatsHandle().SendLoadRequests(stmtCtx, neededColumns, timeout) + if err != nil { + return handleTimeout(stmtCtx) + } + return nil +} + +// SyncWaitStatsLoad sync-wait for stats load until timeout +func SyncWaitStatsLoad(plan LogicalPlan) (bool, error) { + stmtCtx := plan.SCtx().GetSessionVars().StmtCtx + if stmtCtx.StatsLoad.Fallback { + return false, nil + } + success := domain.GetDomain(plan.SCtx()).StatsHandle().SyncWaitStatsLoad(stmtCtx) + if success { + return true, nil + } + err := handleTimeout(stmtCtx) + return false, err +} + +func handleTimeout(stmtCtx *stmtctx.StatementContext) error { + err := errors.New("Timeout when sync-load full stats for needed columns") + if variable.StatsLoadPseudoTimeout.Load() { + stmtCtx.AppendWarning(err) + stmtCtx.StatsLoad.Fallback = true + return nil + } + return err +} diff --git a/planner/core/plan_stats_test.go b/planner/core/plan_stats_test.go new file mode 100644 index 0000000000000..63396ab7ce3f0 --- /dev/null +++ b/planner/core/plan_stats_test.go @@ -0,0 +1,292 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core_test + +import ( + "context" + "fmt" + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/executor" + "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/planner" + plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/statistics" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testleak" +) + +var _ = Suite(&testPlanStatsSuite{}) + +type testPlanStatsSuite struct { + *parser.Parser +} + +func (s *testPlanStatsSuite) SetUpSuite(c *C) { + s.Parser = parser.New() +} + +func (s *testPlanStatsSuite) TearDownSuite(c *C) { +} + +func (s *testPlanStatsSuite) TestPlanStatsLoad(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Check(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + tk := testkit.NewTestKit(c, store) + tk.MustExec("use test") + ctx := tk.Se.(sessionctx.Context) + tk.MustExec("drop table if exists t") + tk.MustExec("set @@session.tidb_analyze_version=2") + tk.MustExec("set @@session.tidb_partition_prune_mode = 'static'") + tk.MustExec("set @@session.tidb_stats_load_sync_wait =9999999") + tk.MustExec("create table t(a int, b int, c int, d int, primary key(a), key idx(b))") + tk.MustExec("insert into t values (1,1,1,1),(2,2,2,2),(3,3,3,3)") + tk.MustExec("create table pt(a int, b int, c int) partition by range(a) (partition p0 values less than (10), partition p1 values less than (20), partition p2 values less than maxvalue)") + tk.MustExec("insert into pt values (1,1,1),(2,2,2),(13,13,13),(14,14,14),(25,25,25),(36,36,36)") + + oriLease := dom.StatsHandle().Lease() + dom.StatsHandle().SetLease(1) + defer func() { + dom.StatsHandle().SetLease(oriLease) + }() + tk.MustExec("analyze table t") + tk.MustExec("analyze table pt") + + testCases := []struct { + sql string + skip bool + check func(p plannercore.Plan, tableInfo *model.TableInfo) + }{ + { // DataSource + sql: "select * from t where c>1", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + switch pp := p.(type) { + case *plannercore.PhysicalTableReader: + stats := pp.Stats().HistColl + c.Assert(countFullStats(stats, tableInfo.Columns[1].ID), Equals, 0) + c.Assert(countFullStats(stats, tableInfo.Columns[2].ID), Greater, 0) + default: + c.Error("unexpected plan:", pp) + } + }, + }, + { // PartitionTable + sql: "select * from pt where a < 15 and c > 1", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + pua, ok := p.(*plannercore.PhysicalUnionAll) + c.Check(ok, IsTrue) + for _, child := range pua.Children() { + c.Assert(countFullStats(child.Stats().HistColl, tableInfo.Columns[2].ID), Greater, 0) + } + }, + }, + { // Join + sql: "select * from t t1 inner join t t2 on t1.b=t2.b where t1.d=3", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + pp, ok := p.(plannercore.PhysicalPlan) + c.Check(ok, IsTrue) + c.Assert(countFullStats(pp.Children()[0].Stats().HistColl, tableInfo.Columns[3].ID), Greater, 0) + 
c.Assert(countFullStats(pp.Children()[1].Stats().HistColl, tableInfo.Columns[3].ID), Greater, 0) + }, + }, + { // Apply + sql: "select * from t t1 where t1.b > (select count(*) from t t2 where t2.c > t1.a and t2.d>1) and t1.c>2", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + pp, ok := p.(*plannercore.PhysicalProjection) + c.Check(ok, IsTrue) + pa, ok := pp.Children()[0].(*plannercore.PhysicalApply) + c.Check(ok, IsTrue) + left := pa.PhysicalHashJoin.Children()[0] + right := pa.PhysicalHashJoin.Children()[1] + c.Assert(countFullStats(left.Stats().HistColl, tableInfo.Columns[2].ID), Greater, 0) + c.Assert(countFullStats(right.Stats().HistColl, tableInfo.Columns[3].ID), Greater, 0) + }, + }, + { // > Any + sql: "select * from t where t.b > any(select d from t where t.c > 2)", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + ph, ok := p.(*plannercore.PhysicalHashJoin) + c.Check(ok, IsTrue) + ptr, ok := ph.Children()[0].(*plannercore.PhysicalTableReader) + c.Check(ok, IsTrue) + c.Assert(countFullStats(ptr.Stats().HistColl, tableInfo.Columns[2].ID), Greater, 0) + }, + }, + { // in + sql: "select * from t where t.b in (select d from t where t.c > 2)", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + ph, ok := p.(*plannercore.PhysicalHashJoin) + c.Check(ok, IsTrue) + ptr, ok := ph.Children()[1].(*plannercore.PhysicalTableReader) + c.Check(ok, IsTrue) + c.Assert(countFullStats(ptr.Stats().HistColl, tableInfo.Columns[2].ID), Greater, 0) + }, + }, + { // not in + sql: "select * from t where t.b not in (select d from t where t.c > 2)", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + ph, ok := p.(*plannercore.PhysicalHashJoin) + c.Check(ok, IsTrue) + ptr, ok := ph.Children()[1].(*plannercore.PhysicalTableReader) + c.Check(ok, IsTrue) + c.Assert(countFullStats(ptr.Stats().HistColl, tableInfo.Columns[2].ID), Greater, 0) + }, + }, + { // exists + sql: "select * from t t1 where exists (select * from t t2 where t1.b > t2.d and t2.c>1)", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + ph, ok := p.(*plannercore.PhysicalHashJoin) + c.Check(ok, IsTrue) + ptr, ok := ph.Children()[1].(*plannercore.PhysicalTableReader) + c.Check(ok, IsTrue) + c.Assert(countFullStats(ptr.Stats().HistColl, tableInfo.Columns[2].ID), Greater, 0) + }, + }, + { // not exists + sql: "select * from t t1 where not exists (select * from t t2 where t1.b > t2.d and t2.c>1)", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + ph, ok := p.(*plannercore.PhysicalHashJoin) + c.Check(ok, IsTrue) + ptr, ok := ph.Children()[1].(*plannercore.PhysicalTableReader) + c.Check(ok, IsTrue) + c.Assert(countFullStats(ptr.Stats().HistColl, tableInfo.Columns[2].ID), Greater, 0) + }, + }, + { // CTE + sql: "with cte(x, y) as (select d + 1, b from t where c > 1) select * from cte where x < 3", + check: func(p plannercore.Plan, tableInfo *model.TableInfo) { + ps, ok := p.(*plannercore.PhysicalSelection) + c.Check(ok, IsTrue) + pc, ok := ps.Children()[0].(*plannercore.PhysicalCTE) + c.Check(ok, IsTrue) + pp, ok := pc.SeedPlan.(*plannercore.PhysicalProjection) + c.Check(ok, IsTrue) + reader, ok := pp.Children()[0].(*plannercore.PhysicalTableReader) + c.Check(ok, IsTrue) + c.Assert(countFullStats(reader.Stats().HistColl, tableInfo.Columns[2].ID), Greater, 0) + }, + }, + { // recursive CTE + sql: "with recursive cte(x, y) as (select a, b from t where c > 1 union select x + 1, y from cte where x < 5) select * from cte", + check: func(p plannercore.Plan, tableInfo 
*model.TableInfo) { + pc, ok := p.(*plannercore.PhysicalCTE) + c.Check(ok, IsTrue) + pp, ok := pc.SeedPlan.(*plannercore.PhysicalProjection) + c.Check(ok, IsTrue) + reader, ok := pp.Children()[0].(*plannercore.PhysicalTableReader) + c.Check(ok, IsTrue) + c.Assert(countFullStats(reader.Stats().HistColl, tableInfo.Columns[2].ID), Greater, 0) + }, + }, + } + for _, testCase := range testCases { + if testCase.skip { + continue + } + is := dom.InfoSchema() + dom.StatsHandle().Clear() // clear statsCache + c.Assert(dom.StatsHandle().Update(is), IsNil) + stmt, err := s.ParseOneStmt(testCase.sql, "", "") + c.Check(err, IsNil) + err = executor.ResetContextOfStmt(ctx, stmt) + c.Assert(err, IsNil) + p, _, err := planner.Optimize(context.TODO(), ctx, stmt, is) + c.Check(err, IsNil) + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + tableInfo := tbl.Meta() + testCase.check(p, tableInfo) + } +} + +func countFullStats(stats *statistics.HistColl, colID int64) int { + for _, col := range stats.Columns { + if col.Info.ID == colID { + return col.Histogram.Len() + col.TopN.Num() + } + } + return -1 +} + +func (s *testPlanStatsSuite) TestPlanStatsLoadTimeout(c *C) { + originConfig := config.GetGlobalConfig() + newConfig := config.NewConfig() + newConfig.Performance.StatsLoadConcurrency = 0 // no worker to consume channel + newConfig.Performance.StatsLoadQueueSize = 1 + config.StoreGlobalConfig(newConfig) + defer config.StoreGlobalConfig(originConfig) + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Check(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + tk := testkit.NewTestKit(c, store) + tk.MustExec("use test") + originalVal1 := tk.MustQuery("select @@tidb_stats_load_pseudo_timeout").Rows()[0][0].(string) + defer func() { + tk.MustExec(fmt.Sprintf("set global tidb_stats_load_pseudo_timeout = %v", originalVal1)) + }() + + ctx := tk.Se.(sessionctx.Context) + tk.MustExec("drop table if exists t") + tk.MustExec("set @@session.tidb_analyze_version=2") + // since queue full, make sync-wait return as timeout as soon as possible + tk.MustExec("set @@session.tidb_stats_load_sync_wait = 1") + tk.MustExec("create table t(a int, b int, c int, primary key(a))") + tk.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3)") + + oriLease := dom.StatsHandle().Lease() + dom.StatsHandle().SetLease(1) + defer func() { + dom.StatsHandle().SetLease(oriLease) + }() + tk.MustExec("analyze table t") + is := dom.InfoSchema() + c.Assert(dom.StatsHandle().Update(is), IsNil) + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + tableInfo := tbl.Meta() + neededColumn := model.TableColumnID{TableID: tableInfo.ID, ColumnID: tableInfo.Columns[0].ID} + resultCh := make(chan model.TableColumnID, 1) + timeout := time.Duration(1<<63 - 1) + dom.StatsHandle().AppendNeededColumn(neededColumn, resultCh, timeout) // make channel queue full + stmt, err := s.ParseOneStmt("select * from t where c>1", "", "") + c.Check(err, IsNil) + tk.MustExec("set global tidb_stats_load_pseudo_timeout=false") + _, _, err = planner.Optimize(context.TODO(), ctx, stmt, is) + c.Check(err, NotNil) // fail sql for timeout when pseudo=false + tk.MustExec("set global tidb_stats_load_pseudo_timeout=true") + plan, _, err := planner.Optimize(context.TODO(), ctx, stmt, is) + c.Check(err, IsNil) // not fail sql for timeout when pseudo=true + switch pp := plan.(type) { + case *plannercore.PhysicalTableReader: + stats := pp.Stats().HistColl + 
c.Assert(countFullStats(stats, tableInfo.Columns[0].ID), Greater, 0) + c.Assert(countFullStats(stats, tableInfo.Columns[2].ID), Equals, 0) // pseudo stats + default: + c.Error("unexpected plan:", pp) + } +} diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json index 1be2b5baceaec..69d094c9b3c4c 100644 --- a/planner/core/testdata/integration_suite_in.json +++ b/planner/core/testdata/integration_suite_in.json @@ -382,6 +382,32 @@ "select a from ta group by @n:=@n+a" ] }, + { + "name": "TestIssue30200", + "cases": [ + // to_base64 and from_base64 have not been pushed to TiKV or TiFlash. + // We expect a Selection will be added above IndexMerge. + "select /*+ use_index_merge(t1) */ 1 from t1 where c1 = 'de' or c2 = '10' and from_base64(to_base64(c1)) = 'ab';", + + // `left` has not been pushed to TiKV, but it has been pushed to TiFlash. + // We expect a Selection will be added above IndexMerge. + "select /*+ use_index_merge(t1) */ 1 from t1 where c1 = 'ab' or c2 = '10' and char_length(left(c1, 10)) = 10;", + + // c3 is part of idx_1, so it will be put in partial_path's IndexFilters instead of TableFilters. + // But it still cannot be pushed to TiKV. This case covers code in DataSource.buildIndexMergeOrPath. + "select /*+ use_index_merge(tt1) */ 1 from tt1 where c1 = 'de' or c2 = '10' and from_base64(to_base64(c3)) = '10';", + + // to_base64(left(pk, 5)) is in partial_path's TableFilters. But it cannot be pushed to TiKV. + // So it should be executed in TiDB. This case covers code in DataSource.buildIndexMergeOrPath. + "select /*+ use_index_merge( tt2 ) */ 1 from tt2 where tt2.c1 in (-3896405) or tt2.pk in (1, 53330) and to_base64(left(pk, 5));", + + // This case covers the expression index. + "select /*+ use_index_merge(tt3) */ 1 from tt3 where c1 < -10 or c2 < 10 and reverse(c3) = '2';", + + // If no hint, we cannot use index merge if filter cannot be pushed to any storage.
+ "select 1 from t1 where c1 = 'de' or c2 = '10' and from_base64(to_base64(c1)) = 'ab';" + ] + }, { "name": "TestIndexMergeWithCorrelatedColumns", "cases": [ diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index f4726f876c70d..5efa0586a2cd9 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -2056,6 +2056,93 @@ } ] }, + { + "Name": "TestIssue30200", + "Cases": [ + { + "SQL": "select /*+ use_index_merge(t1) */ 1 from t1 where c1 = 'de' or c2 = '10' and from_base64(to_base64(c1)) = 'ab';", + "Plan": [ + "Projection 15.99 root 1->Column#5", + "└─Selection 15.99 root or(eq(test.t1.c1, \"de\"), and(eq(test.t1.c2, \"10\"), eq(from_base64(to_base64(test.t1.c1)), \"ab\")))", + " └─IndexMerge 19.99 root ", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t1, index:c1(c1) range:[\"de\",\"de\"], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t1, index:c2(c2) range:[\"10\",\"10\"], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 19.99 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Res": [ + "1" + ] + }, + { + "SQL": "select /*+ use_index_merge(t1) */ 1 from t1 where c1 = 'ab' or c2 = '10' and char_length(left(c1, 10)) = 10;", + "Plan": [ + "Projection 17.99 root 1->Column#5", + "└─Selection 0.04 root or(eq(test.t1.c1, \"ab\"), and(eq(test.t1.c2, \"10\"), eq(char_length(left(test.t1.c1, 10)), 10)))", + " └─IndexMerge 19.99 root ", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t1, index:c1(c1) range:[\"ab\",\"ab\"], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t1, index:c2(c2) range:[\"10\",\"10\"], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 19.99 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Res": [ + "1" + ] + }, + { + "SQL": "select /*+ use_index_merge(tt1) */ 1 from tt1 where c1 = 'de' or c2 = '10' and from_base64(to_base64(c3)) = '10';", + "Plan": [ + "Projection 15.99 root 1->Column#6", + "└─Selection 15.99 root or(eq(test.tt1.c1, \"de\"), and(eq(test.tt1.c2, \"10\"), eq(from_base64(to_base64(test.tt1.c3)), \"10\")))", + " └─IndexMerge 19.99 root ", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:tt1, index:idx_0(c1) range:[\"de\",\"de\"], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:tt1, index:idx_1(c2, c3) range:[\"10\",\"10\"], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 19.99 cop[tikv] table:tt1 keep order:false, stats:pseudo" + ], + "Res": [ + "1" + ] + }, + { + "SQL": "select /*+ use_index_merge( tt2 ) */ 1 from tt2 where tt2.c1 in (-3896405) or tt2.pk in (1, 53330) and to_base64(left(pk, 5));", + "Plan": [ + "Projection 2.40 root 1->Column#3", + "└─Selection 2.40 root or(eq(test.tt2.c1, -3896405), and(in(test.tt2.pk, 1, 53330), istrue_with_null(cast(to_base64(left(cast(test.tt2.pk, var_string(20)), 5)), double BINARY))))", + " └─IndexMerge 3.00 root ", + " ├─IndexRangeScan(Build) 1.00 cop[tikv] table:tt2, index:c1(c1) range:[-3896405,-3896405], keep order:false, stats:pseudo", + " ├─TableRangeScan(Build) 2.00 cop[tikv] table:tt2 range:[1,1], [53330,53330], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 3.00 cop[tikv] table:tt2 keep order:false, stats:pseudo" + ], + "Res": [ + "1" + ] + }, + { + "SQL": "select /*+ use_index_merge(tt3) */ 1 from tt3 where c1 < -10 or c2 < 10 and reverse(c3) = '2';", + "Plan": [ + "Projection 5098.44 root 1->Column#5", + 
"└─Selection 2825.66 root or(lt(test.tt3.c1, -10), and(lt(test.tt3.c2, 10), eq(reverse(cast(test.tt3.c3, var_string(20))), \"2\")))", + " └─IndexMerge 5542.21 root ", + " ├─IndexRangeScan(Build) 3323.33 cop[tikv] table:tt3, index:c1(c1) range:[-inf,-10), keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 3323.33 cop[tikv] table:tt3, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 5542.21 cop[tikv] table:tt3 keep order:false, stats:pseudo" + ], + "Res": [ + "1" + ] + }, + { + "SQL": "select 1 from t1 where c1 = 'de' or c2 = '10' and from_base64(to_base64(c1)) = 'ab';", + "Plan": [ + "Projection 8000.00 root 1->Column#5", + "└─Selection 8000.00 root or(eq(test.t1.c1, \"de\"), and(eq(test.t1.c2, \"10\"), eq(from_base64(to_base64(test.t1.c1)), \"ab\")))", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Res": [ + "1" + ] + } + ] + }, { "Name": "TestIndexMergeWithCorrelatedColumns", "Cases": [ @@ -2068,14 +2155,13 @@ " ├─TableReader(Build) 10000.00 root data:TableFullScan", " │ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", " └─StreamAgg(Probe) 1.00 root funcs:min(test.t1.c1)->Column#8, funcs:sum(0)->Column#9, funcs:count(1)->Column#10", - " └─Selection 0.01 root substring(cast(test.t1.c3, var_string(20)), 10)", - " └─IndexMerge 0.01 root ", - " ├─Selection(Build) 1.00 cop[tikv] eq(10, test.t2.c3)", - " │ └─TableRangeScan 1.00 cop[tikv] table:t1 range:[10,10], keep order:false, stats:pseudo", - " ├─Selection(Build) 8.00 cop[tikv] eq(1, test.t2.c3)", - " │ └─IndexRangeScan 10.00 cop[tikv] table:t1, index:c2(c2) range:[1,1], keep order:false, stats:pseudo", - " └─Selection(Probe) 0.01 cop[tikv] or(and(eq(test.t1.c1, 10), eq(10, test.t2.c3)), and(eq(test.t1.c2, 1), eq(1, test.t2.c3)))", - " └─TableRowIDScan 9.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + " └─IndexMerge 0.01 root ", + " ├─Selection(Build) 1.00 cop[tikv] eq(10, test.t2.c3)", + " │ └─TableRangeScan 1.00 cop[tikv] table:t1 range:[10,10], keep order:false, stats:pseudo", + " ├─Selection(Build) 8.00 cop[tikv] eq(1, test.t2.c3)", + " │ └─IndexRangeScan 10.00 cop[tikv] table:t1, index:c2(c2) range:[1,1], keep order:false, stats:pseudo", + " └─Selection(Probe) 0.01 cop[tikv] or(and(eq(test.t1.c1, 10), eq(10, test.t2.c3)), and(eq(test.t1.c2, 1), eq(1, test.t2.c3))), substring(cast(test.t1.c3, var_string(20)), 10)", + " └─TableRowIDScan 9.00 cop[tikv] table:t1 keep order:false, stats:pseudo" ], "Res": [ "1 1 1", @@ -2112,14 +2198,13 @@ " ├─TableReader(Build) 10000.00 root data:TableFullScan", " │ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", " └─StreamAgg(Probe) 1.00 root funcs:min(test.t1.c1)->Column#8, funcs:sum(0)->Column#9, funcs:count(1)->Column#10", - " └─Selection 3.03 root substring(cast(test.t1.c3, var_string(20)), 10)", - " └─IndexMerge 3.78 root ", - " ├─Selection(Build) 3.33 cop[tikv] eq(test.t1.c1, test.t2.c3)", - " │ └─TableRangeScan 3333.33 cop[tikv] table:t1 range:[10,+inf], keep order:false, stats:pseudo", - " ├─Selection(Build) 8.00 cop[tikv] eq(1, test.t2.c3)", - " │ └─IndexRangeScan 10.00 cop[tikv] table:t1, index:c2(c2) range:[1,1], keep order:false, stats:pseudo", - " └─Selection(Probe) 3.78 cop[tikv] or(and(ge(test.t1.c1, 10), eq(test.t1.c1, test.t2.c3)), and(eq(test.t1.c2, 1), eq(1, test.t2.c3)))", - " └─TableRowIDScan 3338.67 cop[tikv] table:t1 keep order:false, stats:pseudo" + " └─IndexMerge 3.03 root ", 
+ " ├─Selection(Build) 3.33 cop[tikv] eq(test.t1.c1, test.t2.c3)", + " │ └─TableRangeScan 3333.33 cop[tikv] table:t1 range:[10,+inf], keep order:false, stats:pseudo", + " ├─Selection(Build) 8.00 cop[tikv] eq(1, test.t2.c3)", + " │ └─IndexRangeScan 10.00 cop[tikv] table:t1, index:c2(c2) range:[1,1], keep order:false, stats:pseudo", + " └─Selection(Probe) 3.03 cop[tikv] or(and(ge(test.t1.c1, 10), eq(test.t1.c1, test.t2.c3)), and(eq(test.t1.c2, 1), eq(1, test.t2.c3))), substring(cast(test.t1.c3, var_string(20)), 10)", + " └─TableRowIDScan 3338.67 cop[tikv] table:t1 keep order:false, stats:pseudo" ], "Res": [ "1 1 1", diff --git a/session/session.go b/session/session.go index e0cf570c13dc8..098a82eeb5c9f 100644 --- a/session/session.go +++ b/session/session.go @@ -2707,6 +2707,18 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { return nil, err } + // start sub workers for concurrent stats loading + concurrency := config.GetGlobalConfig().Performance.StatsLoadConcurrency + subCtxs := make([]sessionctx.Context, concurrency) + for i := 0; i < int(concurrency); i++ { + subSe, err := createSession(store) + if err != nil { + return nil, err + } + subCtxs[i] = subSe + } + dom.StartLoadStatsSubWorkers(subCtxs) + dom.PlanReplayerLoop() if raw, ok := store.(kv.EtcdBackend); ok { diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index b7f84af9ffc38..862a3c7dba6f8 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -221,6 +221,19 @@ type StatementContext struct { // WeakConsistency is true when read consistency is weak and in a read statement and not in a transaction. WeakConsistency bool + + StatsLoad struct { + // Timeout to wait for sync-load + Timeout time.Duration + // NeededColumns stores the columns whose stats are needed for planner. + NeededColumns []model.TableColumnID + // ResultCh to receive stats loading results + ResultCh chan model.TableColumnID + // Fallback indicates if the planner uses full-loaded stats or fallback all to pseudo/simple. + Fallback bool + // LoadStartTime is to record the load start time to calculate latency + LoadStartTime time.Time + } } // StmtHints are SessionVars related sql hints. diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index d0d34554f2379..73c900b60a4d4 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1007,6 +1007,9 @@ type SessionVars struct { // ReadConsistency indicates the read consistency requirement. ReadConsistency ReadConsistencyLevel + + // StatsLoadSyncWait indicates how long to wait for stats load before timeout. + StatsLoadSyncWait int64 } // InitStatementContext initializes a StatementContext, the object is reused to reduce allocation. 
@@ -1242,6 +1245,7 @@ func NewSessionVars() *SessionVars { EnablePlacementChecks: DefEnablePlacementCheck, Rng: utilMath.NewWithTime(), StmtStats: stmtstats.CreateStatementStats(), + StatsLoadSyncWait: StatsLoadSyncWait.Load(), } vars.KVVars = tikvstore.NewVariables(&vars.Killed) vars.Concurrency = Concurrency{ diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index ed1f9180de160..564518de6ccd0 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -1358,6 +1358,28 @@ var defaultSysVars = []*SysVar{ return nil }, }, + {Scope: ScopeGlobal | ScopeSession, Name: TiDBStatsLoadSyncWait, Value: strconv.Itoa(DefTiDBStatsLoadSyncWait), skipInit: true, Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + SetSession: func(s *SessionVars, val string) error { + s.StatsLoadSyncWait = tidbOptInt64(val, DefTiDBStatsLoadSyncWait) + return nil + }, + GetGlobal: func(s *SessionVars) (string, error) { + return strconv.FormatInt(StatsLoadSyncWait.Load(), 10), nil + }, + SetGlobal: func(s *SessionVars, val string) error { + StatsLoadSyncWait.Store(tidbOptInt64(val, DefTiDBStatsLoadSyncWait)) + return nil + }, + }, + {Scope: ScopeGlobal, Name: TiDBStatsLoadPseudoTimeout, Value: BoolToOnOff(DefTiDBStatsLoadPseudoTimeout), skipInit: true, Type: TypeBool, + GetGlobal: func(s *SessionVars) (string, error) { + return strconv.FormatBool(StatsLoadPseudoTimeout.Load()), nil + }, + SetGlobal: func(s *SessionVars, val string) error { + StatsLoadPseudoTimeout.Store(TiDBOptOn(val)) + return nil + }, + }, } // FeedbackProbability points to the FeedbackProbability in statistics package. diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 7d3cd9d632a00..7fa82662db22b 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -606,6 +606,9 @@ const ( // TiDBTmpTableMaxSize indicates the max memory size of temporary tables. TiDBTmpTableMaxSize = "tidb_tmp_table_max_size" + + // TiDBStatsLoadSyncWait indicates the time sql execution will sync-wait for stats load. + TiDBStatsLoadSyncWait = "tidb_stats_load_sync_wait" ) // TiDB vars that have only global scope @@ -633,6 +636,8 @@ const ( // It is used to invalidate the collected predicate columns after turning off TiDBEnableColumnTracking, which avoids physical deletion. // It doesn't have cache in memory and we directly get/set the variable value from/to mysql.tidb. TiDBDisableColumnTrackingTime = "tidb_disable_column_tracking_time" + // TiDBStatsLoadPseudoTimeout indicates whether to fallback to pseudo stats after load timeout. + TiDBStatsLoadPseudoTimeout = "tidb_stats_load_pseudo_timeout" ) // TiDB intentional limits @@ -788,6 +793,8 @@ const ( DefTiDBEnableIndexMerge = true DefTiDBPersistAnalyzeOptions = true DefTiDBEnableColumnTracking = true + DefTiDBStatsLoadSyncWait = 0 + DefTiDBStatsLoadPseudoTimeout = false ) // Process global variables. 
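A brief usage note on the two new variables, following the testkit style used elsewhere in this patch (tk is assumed to be an existing *testkit.TestKit): tidb_stats_load_sync_wait is a session/global value in milliseconds, and the default 0 keeps synchronous loading off, while the global tidb_stats_load_pseudo_timeout decides whether a timed-out wait degrades to pseudo stats with a warning instead of failing the statement. The values below are illustrative only.

tk.MustExec("set @@session.tidb_stats_load_sync_wait = 200")    // wait up to 200ms for full column stats during planning
tk.MustExec("set global tidb_stats_load_pseudo_timeout = true") // on timeout, append a warning and fall back to pseudo stats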
@@ -818,4 +825,6 @@ var ( RestrictedReadOnly = atomic.NewBool(DefTiDBRestrictedReadOnly) PersistAnalyzeOptions = atomic.NewBool(DefTiDBPersistAnalyzeOptions) EnableColumnTracking = atomic.NewBool(DefTiDBEnableColumnTracking) + StatsLoadSyncWait = atomic.NewInt64(DefTiDBStatsLoadSyncWait) + StatsLoadPseudoTimeout = atomic.NewBool(DefTiDBStatsLoadPseudoTimeout) ) diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index aa9fa33edeeb9..8aff021344231 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -28,6 +28,7 @@ import ( "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/parser/ast" @@ -118,6 +119,9 @@ type Handle struct { // idxUsageListHead contains all the index usage collectors required by session. idxUsageListHead *SessionIndexUsageCollector + + // statsLoad is used to load stats concurrently + StatsLoad StatsLoad } func (h *Handle) withRestrictedSQLExecutor(ctx context.Context, fn func(context.Context, sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error)) ([]chunk.Row, []*ast.ResultField, error) { @@ -204,6 +208,7 @@ type sessionPool interface { // NewHandle creates a Handle for update stats. func NewHandle(ctx sessionctx.Context, lease time.Duration, pool sessionPool) (*Handle, error) { + cfg := config.GetGlobalConfig() handle := &Handle{ ddlEventCh: make(chan *util.Event, 100), listHead: &SessionStatsCollector{mapper: make(tableDeltaMap), rateMap: make(errorRateDeltaMap)}, @@ -218,6 +223,10 @@ func NewHandle(ctx sessionctx.Context, lease time.Duration, pool sessionPool) (* handle.globalMap.data = make(tableDeltaMap) handle.feedback.data = statistics.NewQueryFeedbackMap() handle.colMap.data = make(colStatsUsageMap) + handle.StatsLoad.SubCtxs = make([]sessionctx.Context, cfg.Performance.StatsLoadConcurrency) + handle.StatsLoad.NeededColumnsCh = make(chan *NeededColumnTask, cfg.Performance.StatsLoadQueueSize) + handle.StatsLoad.TimeoutColumnsCh = make(chan *NeededColumnTask, cfg.Performance.StatsLoadQueueSize) + handle.StatsLoad.workingColMap = map[model.TableColumnID][]chan model.TableColumnID{} err := handle.RefreshVars() if err != nil { return nil, err @@ -607,13 +616,13 @@ func (sc statsCache) update(tables []*statistics.Table, deletedIDs []int64, newV // LoadNeededHistograms will load histograms for those needed columns. func (h *Handle) LoadNeededHistograms() (err error) { cols := statistics.HistogramNeededColumns.AllCols() - reader, err := h.getStatsReader(0) + reader, err := h.getGlobalStatsReader(0) if err != nil { return err } defer func() { - err1 := h.releaseStatsReader(reader) + err1 := h.releaseGlobalStatsReader(reader) if err1 != nil && err == nil { err = err1 } @@ -872,12 +881,12 @@ func (h *Handle) columnStatsFromStorage(reader *statsReader, row chunk.Row, tabl // TableStatsFromStorage loads table stats info from storage. 
func (h *Handle) TableStatsFromStorage(tableInfo *model.TableInfo, physicalID int64, loadAll bool, snapshot uint64) (_ *statistics.Table, err error) { - reader, err := h.getStatsReader(snapshot) + reader, err := h.getGlobalStatsReader(snapshot) if err != nil { return nil, err } defer func() { - err1 := h.releaseStatsReader(reader) + err1 := h.releaseGlobalStatsReader(reader) if err == nil && err1 != nil { err = err1 } @@ -977,12 +986,12 @@ func (h *Handle) extendedStatsFromStorage(reader *statsReader, table *statistics // StatsMetaCountAndModifyCount reads count and modify_count for the given table from mysql.stats_meta. func (h *Handle) StatsMetaCountAndModifyCount(tableID int64) (int64, int64, error) { - reader, err := h.getStatsReader(0) + reader, err := h.getGlobalStatsReader(0) if err != nil { return 0, 0, err } defer func() { - err1 := h.releaseStatsReader(reader) + err1 := h.releaseGlobalStatsReader(reader) if err1 != nil && err == nil { err = err1 } @@ -1408,38 +1417,51 @@ func (sr *statsReader) isHistory() bool { return sr.snapshot > 0 } -func (h *Handle) getStatsReader(snapshot uint64) (reader *statsReader, err error) { +func (h *Handle) getGlobalStatsReader(snapshot uint64) (reader *statsReader, err error) { + h.mu.Lock() + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("getGlobalStatsReader panic %v", r) + } + if err != nil { + h.mu.Unlock() + } + }() + return h.getStatsReader(snapshot, h.mu.ctx.(sqlexec.RestrictedSQLExecutor)) +} + +func (h *Handle) releaseGlobalStatsReader(reader *statsReader) error { + defer h.mu.Unlock() + return h.releaseStatsReader(reader, h.mu.ctx.(sqlexec.RestrictedSQLExecutor)) +} + +func (h *Handle) getStatsReader(snapshot uint64, ctx sqlexec.RestrictedSQLExecutor) (reader *statsReader, err error) { failpoint.Inject("mockGetStatsReaderFail", func(val failpoint.Value) { if val.(bool) { failpoint.Return(nil, errors.New("gofail genStatsReader error")) } }) if snapshot > 0 { - return &statsReader{ctx: h.mu.ctx.(sqlexec.RestrictedSQLExecutor), snapshot: snapshot}, nil + return &statsReader{ctx: ctx, snapshot: snapshot}, nil } - h.mu.Lock() defer func() { if r := recover(); r != nil { err = fmt.Errorf("getStatsReader panic %v", r) } - if err != nil { - h.mu.Unlock() - } }() failpoint.Inject("mockGetStatsReaderPanic", nil) - _, err = h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "begin") + _, err = ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "begin") if err != nil { return nil, err } - return &statsReader{ctx: h.mu.ctx.(sqlexec.RestrictedSQLExecutor)}, nil + return &statsReader{ctx: ctx}, nil } -func (h *Handle) releaseStatsReader(reader *statsReader) error { +func (h *Handle) releaseStatsReader(reader *statsReader, ctx sqlexec.RestrictedSQLExecutor) error { if reader.snapshot > 0 { return nil } - _, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "commit") - h.mu.Unlock() + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "commit") return err } @@ -1585,12 +1607,12 @@ func (h *Handle) removeExtendedStatsItem(tableID int64, statsName string) { // ReloadExtendedStatistics drops the cache for extended statistics and reload data from mysql.stats_extended. 
func (h *Handle) ReloadExtendedStatistics() error { - reader, err := h.getStatsReader(0) + reader, err := h.getGlobalStatsReader(0) if err != nil { return err } defer func() { - err1 := h.releaseStatsReader(reader) + err1 := h.releaseGlobalStatsReader(reader) if err1 != nil && err == nil { err = err1 } diff --git a/statistics/handle/handle_hist.go b/statistics/handle/handle_hist.go new file mode 100644 index 0000000000000..e0cb34d88f171 --- /dev/null +++ b/statistics/handle/handle_hist.go @@ -0,0 +1,376 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package handle + +import ( + "runtime" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/statistics" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "go.uber.org/zap" +) + +// StatsLoad is used to load stats concurrently +type StatsLoad struct { + sync.Mutex + SubCtxs []sessionctx.Context + NeededColumnsCh chan *NeededColumnTask + TimeoutColumnsCh chan *NeededColumnTask + workingColMap map[model.TableColumnID][]chan model.TableColumnID +} + +// NeededColumnTask represents one needed column with expire time. 
+type NeededColumnTask struct { + TableColumnID model.TableColumnID + ToTimeout time.Time + ResultCh chan model.TableColumnID +} + +// SendLoadRequests sends load requests for the needed columns +func (h *Handle) SendLoadRequests(sc *stmtctx.StatementContext, neededColumns []model.TableColumnID, timeout time.Duration) error { + missingColumns := h.genHistMissingColumns(neededColumns) + if len(missingColumns) <= 0 { + return nil + } + sc.StatsLoad.Timeout = timeout + sc.StatsLoad.NeededColumns = missingColumns + sc.StatsLoad.ResultCh = make(chan model.TableColumnID, len(neededColumns)) + for _, col := range missingColumns { + err := h.AppendNeededColumn(col, sc.StatsLoad.ResultCh, timeout) + if err != nil { + return err + } + } + sc.StatsLoad.LoadStartTime = time.Now() + return nil +} + +// SyncWaitStatsLoad sync-waits for the needed columns to be loaded and returns false on timeout +func (h *Handle) SyncWaitStatsLoad(sc *stmtctx.StatementContext) bool { + if len(sc.StatsLoad.NeededColumns) <= 0 { + return true + } + defer func() { + if sc.StatsLoad.ResultCh != nil { + close(sc.StatsLoad.ResultCh) + } + sc.StatsLoad.NeededColumns = nil + }() + resultCheckMap := map[model.TableColumnID]struct{}{} + for _, col := range sc.StatsLoad.NeededColumns { + resultCheckMap[col] = struct{}{} + } + metrics.SyncLoadCounter.Inc() + timer := time.NewTimer(sc.StatsLoad.Timeout) + defer timer.Stop() + for { + select { + case result, ok := <-sc.StatsLoad.ResultCh: + if ok { + delete(resultCheckMap, result) + if len(resultCheckMap) == 0 { + metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds())) + return true + } + } + case <-timer.C: + metrics.SyncLoadTimeoutCounter.Inc() + return false + } + } +} + +// genHistMissingColumns returns the columns in neededColumns whose histograms are still missing from the statsCache. +func (h *Handle) genHistMissingColumns(neededColumns []model.TableColumnID) []model.TableColumnID { + statsCache := h.statsCache.Load().(statsCache) + missingColumns := make([]model.TableColumnID, 0, len(neededColumns)) + for _, col := range neededColumns { + tbl, ok := statsCache.tables[col.TableID] + if !ok { + continue + } + colHist, ok := tbl.Columns[col.ColumnID] + if !ok { + continue + } + if colHist.IsHistNeeded(tbl.Pseudo) { + missingColumns = append(missingColumns, col) + } + } + return missingColumns +} + +// AppendNeededColumn appends a needed-column task to the channel, waiting at most timeout if the queue is full.
+func (h *Handle) AppendNeededColumn(c model.TableColumnID, resultCh chan model.TableColumnID, timeout time.Duration) error { + toTimout := time.Now().Local().Add(timeout) + colTask := &NeededColumnTask{TableColumnID: c, ToTimeout: toTimout, ResultCh: resultCh} + return h.writeToChanWithTimeout(h.StatsLoad.NeededColumnsCh, colTask, timeout) +} + +var errExit = errors.New("Stop loading since domain is closed") + +type statsReaderContext struct { + reader *statsReader + createdTime time.Time +} + +// SubLoadWorker loads hist data for each column +func (h *Handle) SubLoadWorker(ctx sessionctx.Context, exit chan struct{}, exitWg *sync.WaitGroup) { + readerCtx := &statsReaderContext{} + defer func() { + exitWg.Done() + logutil.BgLogger().Info("SubLoadWorker exited.") + if readerCtx.reader != nil { + err := h.releaseStatsReader(readerCtx.reader, ctx.(sqlexec.RestrictedSQLExecutor)) + if err != nil { + logutil.BgLogger().Error("Fail to release stats loader: ", zap.Error(err)) + } + } + }() + for { + err := h.handleOneTask(readerCtx, ctx.(sqlexec.RestrictedSQLExecutor), exit) + if err != nil { + switch err { + case errExit: + return + default: + time.Sleep(10 * time.Millisecond) + continue + } + } + } +} + +// handleOneTask handles one column task. +func (h *Handle) handleOneTask(readerCtx *statsReaderContext, ctx sqlexec.RestrictedSQLExecutor, exit chan struct{}) (err error) { + defer func() { + // recover for each task, worker keeps working + if r := recover(); r != nil { + buf := make([]byte, 4096) + stackSize := runtime.Stack(buf, false) + buf = buf[:stackSize] + logutil.BgLogger().Error("stats loading panicked", zap.String("stack", string(buf))) + } + }() + h.getFreshStatsReader(readerCtx, ctx) + task, err := h.drainColTask(exit) + if err != nil { + if err != errExit { + logutil.BgLogger().Error("Fail to drain task for stats loading.", zap.Error(err)) + } + return err + } + col := task.TableColumnID + // to avoid duplicated handling in concurrent scenario + if !h.setWorking(col, task.ResultCh) { + return nil + } + oldCache := h.statsCache.Load().(statsCache) + tbl, ok := oldCache.tables[col.TableID] + if !ok { + task.ResultCh <- col + return nil + } + c, ok := tbl.Columns[col.ColumnID] + if !ok || c.Len() > 0 { + task.ResultCh <- col + return nil + } + t := time.Now() + hist, err := h.readStatsForOne(col, c, readerCtx.reader) + if err != nil { + h.StatsLoad.NeededColumnsCh <- task + return err + } + metrics.ReadStatsHistogram.Observe(float64(time.Since(t).Milliseconds())) + if hist != nil && h.updateCachedColumn(col, hist) { + task.ResultCh <- col + } + h.finishWorking(col) + return nil +} + +func (h *Handle) getFreshStatsReader(readerCtx *statsReaderContext, ctx sqlexec.RestrictedSQLExecutor) { + if readerCtx.reader == nil || readerCtx.createdTime.Add(h.Lease()).Before(time.Now()) { + if readerCtx.reader != nil { + err := h.releaseStatsReader(readerCtx.reader, ctx) + if err != nil { + logutil.BgLogger().Warn("Fail to release stats loader: ", zap.Error(err)) + } + } + for { + newReader, err := h.getStatsReader(0, ctx) + if err != nil { + logutil.BgLogger().Error("Fail to new stats loader, retry after a while.", zap.Error(err)) + time.Sleep(time.Millisecond * 10) + } else { + readerCtx.reader = newReader + readerCtx.createdTime = time.Now() + return + } + } + } else { + return + } +} + +// readStatsForOne reads hist for one column, TODO load data via kv-get asynchronously +func (h *Handle) readStatsForOne(col model.TableColumnID, c *statistics.Column, reader *statsReader) (*statistics.Column, 
+
+// readStatsForOne reads hist for one column. TODO: load the data asynchronously via kv-get.
+func (h *Handle) readStatsForOne(col model.TableColumnID, c *statistics.Column, reader *statsReader) (*statistics.Column, error) {
+	hg, err := h.histogramFromStorage(reader, col.TableID, c.ID, &c.Info.FieldType, c.Histogram.NDV, 0, c.LastUpdateVersion, c.NullCount, c.TotColSize, c.Correlation)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	cms, topN, err := h.cmSketchAndTopNFromStorage(reader, col.TableID, 0, col.ColumnID)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	fms, err := h.fmSketchFromStorage(reader, col.TableID, 0, col.ColumnID)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	rows, _, err := reader.read("select stats_ver from mysql.stats_histograms where is_index = 0 and table_id = %? and hist_id = %?", col.TableID, col.ColumnID)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	if len(rows) == 0 {
+		logutil.BgLogger().Error("fail to get stats version for this histogram", zap.Int64("table_id", col.TableID), zap.Int64("hist_id", col.ColumnID))
+	}
+	colHist := &statistics.Column{
+		PhysicalID: col.TableID,
+		Histogram:  *hg,
+		Info:       c.Info,
+		CMSketch:   cms,
+		TopN:       topN,
+		FMSketch:   fms,
+		IsHandle:   c.IsHandle,
+		StatsVer:   rows[0].GetInt64(0),
+	}
+	// Column.Count is calculated by Column.TotalRowCount(). Hence, we don't set Column.Count when initializing colHist.
+	colHist.Count = int64(colHist.TotalRowCount())
+	return colHist, nil
+}
+
+// drainColTask blocks until a column task can be returned.
+func (h *Handle) drainColTask(exit chan struct{}) (*NeededColumnTask, error) {
+	// select NeededColumnsCh first; only if it has no task, fall back to TimeoutColumnsCh
+	for {
+		select {
+		case <-exit:
+			return nil, errExit
+		case task, ok := <-h.StatsLoad.NeededColumnsCh:
+			if !ok {
+				return nil, errors.New("drainColTask: cannot read from NeededColumnsCh, maybe the chan is closed")
+			}
+			// if the task has already timed out, no SQL statement is sync-waiting for it,
+			// so do not handle it now; put it into another channel with lower priority
+			if time.Now().After(task.ToTimeout) {
+				h.writeToChanNonblocking(h.StatsLoad.TimeoutColumnsCh, task)
+				continue
+			}
+			return task, nil
+		case task, ok := <-h.StatsLoad.TimeoutColumnsCh:
+			select {
+			case <-exit:
+				return nil, errExit
+			case task0, ok0 := <-h.StatsLoad.NeededColumnsCh:
+				if !ok0 {
+					return nil, errors.New("drainColTask: cannot read from NeededColumnsCh, maybe the chan is closed")
+				}
+				// send the task back to TimeoutColumnsCh and return the task drained from NeededColumnsCh
+				h.writeToChanNonblocking(h.StatsLoad.TimeoutColumnsCh, task)
+				return task0, nil
+			default:
+				if !ok {
+					return nil, errors.New("drainColTask: cannot read from TimeoutColumnsCh, maybe the chan is closed")
+				}
+				// NeededColumnsCh is empty now, handle the task from TimeoutColumnsCh
+				return task, nil
+			}
+		}
+	}
+}
+
+// writeToChanNonblocking writes in a non-blocking way; if the channel queue is full, the task is simply dropped.
+func (h *Handle) writeToChanNonblocking(taskCh chan *NeededColumnTask, task *NeededColumnTask) {
+	select {
+	case taskCh <- task:
+	default:
+	}
+}
+
+// writeToChanWithTimeout writes a task to a channel, blocking for at most the given timeout.
+func (h *Handle) writeToChanWithTimeout(taskCh chan *NeededColumnTask, task *NeededColumnTask, timeout time.Duration) error {
+	timer := time.NewTimer(timeout)
+	defer timer.Stop()
+	select {
+	case taskCh <- task:
+	case <-timer.C:
+		return errors.New("channel is full and writing to it timed out")
+	}
+	return nil
+}
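drainColTask gives NeededColumnsCh strict priority over TimeoutColumnsCh and demotes already-expired tasks with a non-blocking send, so a full queue can never deadlock the worker. The sketch below (not part of the patch; the task type and channel names are made up) isolates that two-queue pattern.

package sketch

import "time"

type task struct {
	deadline time.Time
}

// sendNonblocking mirrors writeToChanNonblocking: if the queue is full, the
// task is dropped rather than blocking the worker.
func sendNonblocking(ch chan *task, t *task) {
	select {
	case ch <- t:
	default:
	}
}

// drain prefers the high-priority queue; expired tasks are demoted to the
// low-priority queue, and a low-priority task is only returned when the
// high-priority queue is empty.
func drain(hi, lo chan *task, exit chan struct{}) *task {
	for {
		select {
		case <-exit:
			return nil
		case t := <-hi:
			if time.Now().After(t.deadline) {
				sendNonblocking(lo, t)
				continue
			}
			return t
		case t := <-lo:
			select {
			case t0 := <-hi:
				// a high-priority task showed up: park t again and take t0
				sendNonblocking(lo, t)
				return t0
			default:
				return t
			}
		}
	}
}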
+
+// updateCachedColumn updates the column hist to global statsCache.
+func (h *Handle) updateCachedColumn(col model.TableColumnID, colHist *statistics.Column) (updated bool) {
+	h.StatsLoad.Lock()
+	defer h.StatsLoad.Unlock()
+	// Reload the latest stats cache, otherwise the `updateStatsCache` may fail with high probability, because functions
+	// like `GetPartitionStats` called in `fmSketchFromStorage` would have modified the stats cache already.
+	oldCache := h.statsCache.Load().(statsCache)
+	tbl, ok := oldCache.tables[col.TableID]
+	if !ok {
+		return true
+	}
+	c, ok := tbl.Columns[col.ColumnID]
+	if !ok || c.Len() > 0 {
+		return true
+	}
+	tbl = tbl.Copy()
+	tbl.Columns[c.ID] = colHist
+	return h.updateStatsCache(oldCache.update([]*statistics.Table{tbl}, nil, oldCache.version))
+}
+
+func (h *Handle) setWorking(col model.TableColumnID, resultCh chan model.TableColumnID) bool {
+	h.StatsLoad.Lock()
+	defer h.StatsLoad.Unlock()
+	chList, ok := h.StatsLoad.workingColMap[col]
+	if ok {
+		h.StatsLoad.workingColMap[col] = append(chList, resultCh)
+		return false
+	}
+	chList = []chan model.TableColumnID{}
+	chList = append(chList, resultCh)
+	h.StatsLoad.workingColMap[col] = chList
+	return true
+}
+
+func (h *Handle) finishWorking(col model.TableColumnID) {
+	h.StatsLoad.Lock()
+	defer h.StatsLoad.Unlock()
+	if chList, ok := h.StatsLoad.workingColMap[col]; ok {
+		for _, ch := range chList {
+			ch <- col
+		}
+	}
+	delete(h.StatsLoad.workingColMap, col)
+}
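setWorking and finishWorking deduplicate concurrent loads of the same column: only the first requester does the work, later requesters just subscribe their result channel and all of them are notified when the load finishes. A compact sketch of that idea follows, with illustrative names and a result channel assumed to be buffered (as it is in the patch).

package sketch

import "sync"

type colKey struct{ tableID, columnID int64 }

type loadDedup struct {
	mu      sync.Mutex
	waiting map[colKey][]chan colKey
}

func newLoadDedup() *loadDedup {
	return &loadDedup{waiting: make(map[colKey][]chan colKey)}
}

// start reports whether the caller should perform the load; either way, ch is
// registered and will be notified once the load for k finishes.
func (d *loadDedup) start(k colKey, ch chan colKey) bool {
	d.mu.Lock()
	defer d.mu.Unlock()
	if subs, ok := d.waiting[k]; ok {
		d.waiting[k] = append(subs, ch)
		return false // someone else is already loading this column
	}
	d.waiting[k] = []chan colKey{ch}
	return true
}

// finish notifies every subscriber (the channels are assumed to be buffered)
// and removes the entry.
func (d *loadDedup) finish(k colKey) {
	d.mu.Lock()
	defer d.mu.Unlock()
	for _, ch := range d.waiting[k] {
		ch <- k
	}
	delete(d.waiting, k)
}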
"github.com/pingcap/check" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/util/testkit" +) + +var _ = Suite(&testLoadHistSuite{}) + +type testLoadHistSuite struct { + testSuiteBase +} + +func (s *testLoadHistSuite) TestConcurrentLoadHist(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + testKit.MustExec("use test") + testKit.MustExec("drop table if exists t") + testKit.MustExec("set @@session.tidb_analyze_version=2") + testKit.MustExec("create table t(a int, b int, c int, primary key(a), key idx(b))") + testKit.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3)") + + oriLease := s.do.StatsHandle().Lease() + s.do.StatsHandle().SetLease(1) + defer func() { + s.do.StatsHandle().SetLease(oriLease) + }() + testKit.MustExec("analyze table t") + + is := s.do.InfoSchema() + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + tableInfo := tbl.Meta() + h := s.do.StatsHandle() + stat := h.GetTableStats(tableInfo) + hg := stat.Columns[tableInfo.Columns[0].ID].Histogram + topn := stat.Columns[tableInfo.Columns[0].ID].TopN + c.Assert(hg.Len()+topn.Num(), Greater, 0) + hg = stat.Columns[tableInfo.Columns[2].ID].Histogram + topn = stat.Columns[tableInfo.Columns[2].ID].TopN + c.Assert(hg.Len()+topn.Num(), Equals, 0) + stmtCtx := &stmtctx.StatementContext{} + neededColumns := make([]model.TableColumnID, 0, len(tableInfo.Columns)) + for _, col := range tableInfo.Columns { + neededColumns = append(neededColumns, model.TableColumnID{TableID: tableInfo.ID, ColumnID: col.ID}) + } + timeout := time.Nanosecond * mathutil.MaxInt + h.SendLoadRequests(stmtCtx, neededColumns, timeout) + rs := h.SyncWaitStatsLoad(stmtCtx) + c.Assert(rs, Equals, true) + stat = h.GetTableStats(tableInfo) + hg = stat.Columns[tableInfo.Columns[2].ID].Histogram + topn = stat.Columns[tableInfo.Columns[2].ID].TopN + c.Assert(hg.Len()+topn.Num(), Greater, 0) +} + +func (s *testLoadHistSuite) TestConcurrentLoadHistTimeout(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + testKit.MustExec("use test") + testKit.MustExec("drop table if exists t") + testKit.MustExec("set @@session.tidb_analyze_version=2") + testKit.MustExec("set @@session.tidb_stats_load_sync_wait =9999999") + testKit.MustExec("create table t(a int, b int, c int, primary key(a), key idx(b))") + testKit.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3)") + + oriLease := s.do.StatsHandle().Lease() + s.do.StatsHandle().SetLease(1) + defer func() { + s.do.StatsHandle().SetLease(oriLease) + }() + testKit.MustExec("analyze table t") + + is := s.do.InfoSchema() + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + tableInfo := tbl.Meta() + h := s.do.StatsHandle() + stat := h.GetTableStats(tableInfo) + hg := stat.Columns[tableInfo.Columns[0].ID].Histogram + topn := stat.Columns[tableInfo.Columns[0].ID].TopN + c.Assert(hg.Len()+topn.Num(), Greater, 0) + hg = stat.Columns[tableInfo.Columns[2].ID].Histogram + topn = stat.Columns[tableInfo.Columns[2].ID].TopN + c.Assert(hg.Len()+topn.Num(), Equals, 0) + stmtCtx := &stmtctx.StatementContext{} + neededColumns := make([]model.TableColumnID, 0, len(tableInfo.Columns)) + for _, col := range tableInfo.Columns { + neededColumns = append(neededColumns, model.TableColumnID{TableID: tableInfo.ID, ColumnID: col.ID}) + } + h.SendLoadRequests(stmtCtx, neededColumns, 0) // set timeout to 0 so task will go to timeout 
diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go
index 4cb91b4928501..d99d3eef75b7e 100644
--- a/statistics/handle/handle_test.go
+++ b/statistics/handle/handle_test.go
@@ -583,6 +583,8 @@ func (s *testStatsSuite) TestLoadStats(c *C) {
 	defer cleanEnv(c, s.store, s.do)
 	testKit := testkit.NewTestKit(c, s.store)
 	testKit.MustExec("use test")
+	testKit.MustExec("drop table if exists t")
+	testKit.MustExec("set @@session.tidb_analyze_version=1")
 	testKit.MustExec("create table t(a int, b int, c int, primary key(a), key idx(b))")
 	testKit.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3)")
 
diff --git a/statistics/handle/update_test.go b/statistics/handle/update_test.go
index 6552d547fb4f3..b04447c20baa0 100644
--- a/statistics/handle/update_test.go
+++ b/statistics/handle/update_test.go
@@ -1465,6 +1465,9 @@ func (s *testStatsSuite) TestLogDetailedInfo(c *C) {
 
 	testKit := testkit.NewTestKit(c, s.store)
 	testKit.MustExec("use test")
+	testKit.MustExec("drop table if exists t")
+	testKit.MustExec("set @@session.tidb_analyze_version=1")
+	testKit.MustExec("set @@session.tidb_stats_load_sync_wait =0")
 	testKit.MustExec("create table t (a bigint(64), b bigint(64), c bigint(64), primary key(a), index idx(b), index idx_ba(b,a), index idx_bc(b,c))")
 	for i := 0; i < 20; i++ {
 		testKit.MustExec(fmt.Sprintf("insert into t values (%d, %d, %d)", i, i, i))
diff --git a/statistics/histogram.go b/statistics/histogram.go
index cd053ec070997..a79cb2a32374f 100644
--- a/statistics/histogram.go
+++ b/statistics/histogram.go
@@ -19,6 +19,7 @@ import (
 	"fmt"
 	"math"
 	"sort"
+	"strconv"
 	"strings"
 	"time"
 	"unsafe"
@@ -1074,12 +1075,27 @@ func (c *Column) IsInvalid(sctx sessionctx.Context, collPseudo bool) bool {
 	if collPseudo && c.NotAccurate() {
 		return true
 	}
-	if c.Histogram.NDV > 0 && c.notNullCount() == 0 && sctx != nil && sctx.GetSessionVars().StmtCtx != nil {
-		HistogramNeededColumns.insert(tableColumnID{TableID: c.PhysicalID, ColumnID: c.Info.ID})
+	if sctx != nil {
+		stmtctx := sctx.GetSessionVars().StmtCtx
+		if stmtctx != nil && stmtctx.StatsLoad.Fallback {
+			return true
+		}
+		if c.Histogram.NDV > 0 && c.notNullCount() == 0 && stmtctx != nil {
+			if stmtctx.StatsLoad.Timeout > 0 {
+				logutil.BgLogger().Warn("Hist for column should already be loaded as sync but not found.",
+					zap.String(strconv.FormatInt(c.Info.ID, 10), c.Info.Name.O))
+			}
+			HistogramNeededColumns.insert(tableColumnID{TableID: c.PhysicalID, ColumnID: c.Info.ID})
+		}
 	}
 	return c.TotalRowCount() == 0 || (c.Histogram.NDV > 0 && c.notNullCount() == 0)
 }
 
+// IsHistNeeded checks if this column needs histogram to be loaded
+func (c *Column) IsHistNeeded(collPseudo bool) bool {
+	return (!collPseudo || !c.NotAccurate()) && c.Histogram.NDV > 0 && c.notNullCount() == 0
+}
+
 func (c *Column) equalRowCount(sctx sessionctx.Context, val types.Datum, encodedVal []byte, realtimeRowCount int64) (float64, error) {
 	if val.IsNull() {
 		return float64(c.NullCount), nil
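Pieced together from the tests and the IsHistNeeded/StatsLoad changes above, a caller-side flow might look like the sketch below. The helper name and the false value passed for collPseudo are assumptions for illustration, and the claim that a timed-out wait is surfaced through stmtCtx.StatsLoad.Fallback (which Column.IsInvalid checks above) is inferred, not shown in this hunk; the actual planner integration lives elsewhere.

package sketch

import (
	"time"

	"github.com/pingcap/tidb/parser/model"
	"github.com/pingcap/tidb/sessionctx/stmtctx"
	"github.com/pingcap/tidb/statistics"
	"github.com/pingcap/tidb/statistics/handle"
)

// syncLoadNeededHists collects the columns whose histograms are not loaded yet,
// sends a single sync-load request, and reports whether the statement received
// full stats before the timeout.
func syncLoadNeededHists(h *handle.Handle, sc *stmtctx.StatementContext,
	tblInfo *model.TableInfo, tblStats *statistics.Table, timeout time.Duration) bool {
	needed := make([]model.TableColumnID, 0, len(tblInfo.Columns))
	for _, col := range tblInfo.Columns {
		// collPseudo is assumed false here purely for illustration
		if c, ok := tblStats.Columns[col.ID]; ok && c.IsHistNeeded(false) {
			needed = append(needed, model.TableColumnID{TableID: tblInfo.ID, ColumnID: col.ID})
		}
	}
	if len(needed) == 0 {
		return true
	}
	h.SendLoadRequests(sc, needed, timeout)
	return h.SyncWaitStatsLoad(sc)
}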