diff --git a/DEPS.bzl b/DEPS.bzl index dca72eca347aa..c9d495cb6e17e 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -760,8 +760,8 @@ def go_deps(): name = "com_github_dgraph_io_ristretto", build_file_proto_mode = "disable_global", importpath = "github.com/dgraph-io/ristretto", - sum = "h1:Wrc3UKTS+cffkOx0xRGFC+ZesNuTfn0ThvEC72N0krk=", - version = "v0.1.1-0.20220403145359-8e850b710d6d", + sum = "h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8=", + version = "v0.1.1", ) go_repository( name = "com_github_dgrijalva_jwt_go", @@ -4434,8 +4434,8 @@ def go_deps(): name = "org_golang_x_sys", build_file_proto_mode = "disable_global", importpath = "golang.org/x/sys", - sum = "h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=", - version = "v0.1.0", + sum = "h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=", + version = "v0.2.0", ) go_repository( name = "org_golang_x_term", diff --git a/ddl/backfilling.go b/ddl/backfilling.go index 35abc16cd1a6f..0c06fa0fab551 100644 --- a/ddl/backfilling.go +++ b/ddl/backfilling.go @@ -744,9 +744,9 @@ func (b *backfillScheduler) initCopReqSenderPool() { logutil.BgLogger().Warn("[ddl-ingest] cannot init cop request sender", zap.Error(err)) return } - copCtx := newCopContext(b.tbl.Meta(), indexInfo, sessCtx) - if copCtx == nil { - logutil.BgLogger().Warn("[ddl-ingest] cannot init cop request sender") + copCtx, err := newCopContext(b.tbl.Meta(), indexInfo, sessCtx) + if err != nil { + logutil.BgLogger().Warn("[ddl-ingest] cannot init cop request sender", zap.Error(err)) return } ver, err := sessCtx.GetStore().CurrentVersion(kv.GlobalTxnScope) diff --git a/ddl/index_cop.go b/ddl/index_cop.go index c1229ae1f7a1e..c5c2476ade5a5 100644 --- a/ddl/index_cop.go +++ b/ddl/index_cop.go @@ -21,13 +21,14 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/distsql" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/model" - "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" + "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" @@ -221,34 +222,145 @@ type copContext struct { colInfos []*model.ColumnInfo fieldTps []*types.FieldType sessCtx sessionctx.Context + + expColInfos []*expression.Column + idxColOutputOffsets []int + handleOutputOffsets []int + virtualColOffsets []int + virtualColFieldTps []*types.FieldType } -func newCopContext(tblInfo *model.TableInfo, idxInfo *model.IndexInfo, sessCtx sessionctx.Context) *copContext { +func newCopContext(tblInfo *model.TableInfo, idxInfo *model.IndexInfo, sessCtx sessionctx.Context) (*copContext, error) { + var err error + usedColumnIDs := make(map[int64]struct{}, len(idxInfo.Columns)) + usedColumnIDs, err = fillUsedColumns(usedColumnIDs, idxInfo, tblInfo) + var handleIDs []int64 + if err != nil { + return nil, err + } + var primaryIdx *model.IndexInfo + if tblInfo.PKIsHandle { + pkCol := tblInfo.GetPkColInfo() + usedColumnIDs[pkCol.ID] = struct{}{} + handleIDs = []int64{pkCol.ID} + } else if tblInfo.IsCommonHandle { + primaryIdx = tables.FindPrimaryIndex(tblInfo) + handleIDs = make([]int64, 0, len(primaryIdx.Columns)) + for _, pkCol := range primaryIdx.Columns { + col := tblInfo.Columns[pkCol.Offset] + handleIDs = append(handleIDs, col.ID) + } + usedColumnIDs, err = fillUsedColumns(usedColumnIDs, primaryIdx, tblInfo) + if err != nil { + return nil, err + } + } + + // Only collect the columns that are used by the index. colInfos := make([]*model.ColumnInfo, 0, len(idxInfo.Columns)) fieldTps := make([]*types.FieldType, 0, len(idxInfo.Columns)) - for _, idxCol := range idxInfo.Columns { - c := tblInfo.Columns[idxCol.Offset] - if c.IsGenerated() && !c.GeneratedStored { - // TODO(tangenta): support reading virtual generated columns. - return nil + for i := range tblInfo.Columns { + col := tblInfo.Columns[i] + if _, found := usedColumnIDs[col.ID]; found { + colInfos = append(colInfos, col) + fieldTps = append(fieldTps, &col.FieldType) } - colInfos = append(colInfos, c) - fieldTps = append(fieldTps, &c.FieldType) } - pkColInfos, pkFieldTps, pkInfo := buildHandleColInfoAndFieldTypes(tblInfo) - colInfos = append(colInfos, pkColInfos...) - fieldTps = append(fieldTps, pkFieldTps...) + // Append the extra handle column when _tidb_rowid is used. + if !tblInfo.HasClusteredIndex() { + extra := model.NewExtraHandleColInfo() + colInfos = append(colInfos, extra) + fieldTps = append(fieldTps, &extra.FieldType) + handleIDs = []int64{extra.ID} + } + + expColInfos, _, err := expression.ColumnInfos2ColumnsAndNames(sessCtx, + model.CIStr{} /* unused */, tblInfo.Name, colInfos, tblInfo) + if err != nil { + return nil, err + } + idxOffsets := resolveIndicesForIndex(expColInfos, idxInfo, tblInfo) + hdColOffsets := resolveIndicesForHandle(expColInfos, handleIDs) + vColOffsets, vColFts := collectVirtualColumnOffsetsAndTypes(expColInfos) copCtx := &copContext{ tblInfo: tblInfo, idxInfo: idxInfo, - pkInfo: pkInfo, + pkInfo: primaryIdx, colInfos: colInfos, fieldTps: fieldTps, sessCtx: sessCtx, + + expColInfos: expColInfos, + idxColOutputOffsets: idxOffsets, + handleOutputOffsets: hdColOffsets, + virtualColOffsets: vColOffsets, + virtualColFieldTps: vColFts, + } + return copCtx, nil +} + +func fillUsedColumns(usedCols map[int64]struct{}, idxInfo *model.IndexInfo, tblInfo *model.TableInfo) (map[int64]struct{}, error) { + colsToChecks := make([]*model.ColumnInfo, 0, len(idxInfo.Columns)) + for _, idxCol := range idxInfo.Columns { + colsToChecks = append(colsToChecks, tblInfo.Columns[idxCol.Offset]) + } + for len(colsToChecks) > 0 { + next := colsToChecks[0] + colsToChecks = colsToChecks[1:] + usedCols[next.ID] = struct{}{} + for depColName := range next.Dependences { + // Expand the virtual generated columns. + depCol := model.FindColumnInfo(tblInfo.Columns, depColName) + if depCol == nil { + return nil, errors.Trace(errors.Errorf("dependent column %s not found", depColName)) + } + if _, ok := usedCols[depCol.ID]; !ok { + colsToChecks = append(colsToChecks, depCol) + } + } } - return copCtx + return usedCols, nil +} + +func resolveIndicesForIndex(outputCols []*expression.Column, idxInfo *model.IndexInfo, tblInfo *model.TableInfo) []int { + offsets := make([]int, 0, len(idxInfo.Columns)) + for _, idxCol := range idxInfo.Columns { + hid := tblInfo.Columns[idxCol.Offset].ID + for j, col := range outputCols { + if col.ID == hid { + offsets = append(offsets, j) + break + } + } + } + return offsets +} + +func resolveIndicesForHandle(cols []*expression.Column, handleIDs []int64) []int { + offsets := make([]int, 0, len(handleIDs)) + for _, hid := range handleIDs { + for j, col := range cols { + if col.ID == hid { + offsets = append(offsets, j) + break + } + } + } + return offsets +} + +func collectVirtualColumnOffsetsAndTypes(cols []*expression.Column) ([]int, []*types.FieldType) { + var offsets []int + var fts []*types.FieldType + for i, col := range cols { + if col.VirtualExpr != nil { + offsets = append(offsets, i) + fts = append(fts, col.GetType()) + } + } + return offsets, fts } func (c *copContext) buildTableScan(ctx context.Context, startTS uint64, start, end kv.Key) (distsql.SelectResult, error) { @@ -284,8 +396,13 @@ func (c *copContext) fetchTableScanResult(ctx context.Context, result distsql.Se return buf, true, nil } iter := chunk.NewIterator4Chunk(chk) + err = table.FillVirtualColumnValue(c.virtualColFieldTps, c.virtualColOffsets, c.expColInfos, c.colInfos, c.sessCtx, chk) + if err != nil { + return nil, false, errors.Trace(err) + } for row := iter.Begin(); row != iter.End(); row = iter.Next() { - idxDt, hdDt := extractIdxValsAndHandle(row, c.idxInfo, c.fieldTps) + idxDt := extractDatumByOffsets(row, c.idxColOutputOffsets, c.expColInfos) + hdDt := extractDatumByOffsets(row, c.handleOutputOffsets, c.expColInfos) handle, err := buildHandle(hdDt, c.tblInfo, c.pkInfo, sctx) if err != nil { return nil, false, errors.Trace(err) @@ -321,34 +438,13 @@ func constructTableScanPB(sCtx sessionctx.Context, tblInfo *model.TableInfo, col return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan}, err } -func buildHandleColInfoAndFieldTypes(tbInfo *model.TableInfo) ([]*model.ColumnInfo, []*types.FieldType, *model.IndexInfo) { - if tbInfo.PKIsHandle { - for i := range tbInfo.Columns { - if mysql.HasPriKeyFlag(tbInfo.Columns[i].GetFlag()) { - return []*model.ColumnInfo{tbInfo.Columns[i]}, []*types.FieldType{&tbInfo.Columns[i].FieldType}, nil - } - } - } else if tbInfo.IsCommonHandle { - primaryIdx := tables.FindPrimaryIndex(tbInfo) - pkCols := make([]*model.ColumnInfo, 0, len(primaryIdx.Columns)) - pkFts := make([]*types.FieldType, 0, len(primaryIdx.Columns)) - for _, pkCol := range primaryIdx.Columns { - pkCols = append(pkCols, tbInfo.Columns[pkCol.Offset]) - pkFts = append(pkFts, &tbInfo.Columns[pkCol.Offset].FieldType) - } - return pkCols, pkFts, primaryIdx - } - extra := model.NewExtraHandleColInfo() - return []*model.ColumnInfo{extra}, []*types.FieldType{&extra.FieldType}, nil -} - -func extractIdxValsAndHandle(row chunk.Row, idxInfo *model.IndexInfo, fieldTps []*types.FieldType) ([]types.Datum, []types.Datum) { - datumBuf := make([]types.Datum, 0, len(fieldTps)) - idxColLen := len(idxInfo.Columns) - for i, ft := range fieldTps { - datumBuf = append(datumBuf, row.GetDatum(i, ft)) +func extractDatumByOffsets(row chunk.Row, offsets []int, expCols []*expression.Column) []types.Datum { + datumBuf := make([]types.Datum, 0, len(offsets)) + for _, offset := range offsets { + c := expCols[offset] + datumBuf = append(datumBuf, row.GetDatum(offset, c.GetType())) } - return datumBuf[:idxColLen], datumBuf[idxColLen:] + return datumBuf } func buildHandle(pkDts []types.Datum, tblInfo *model.TableInfo, diff --git a/ddl/index_cop_test.go b/ddl/index_cop_test.go index 333afa997d3bc..56bdc9297d95c 100644 --- a/ddl/index_cop_test.go +++ b/ddl/index_cop_test.go @@ -37,7 +37,8 @@ func TestAddIndexFetchRowsFromCoprocessor(t *testing.T) { require.NoError(t, err) tblInfo := tbl.Meta() idxInfo := tblInfo.FindIndexByName(idx) - copCtx := ddl.NewCopContext4Test(tblInfo, idxInfo, tk.Session()) + copCtx, err := ddl.NewCopContext4Test(tblInfo, idxInfo, tk.Session()) + require.NoError(t, err) startKey := tbl.RecordPrefix() endKey := startKey.PrefixNext() txn, err := store.Begin() diff --git a/ddl/ingest/config.go b/ddl/ingest/config.go index 3a96e8ae5201b..e9c1458b1ab0a 100644 --- a/ddl/ingest/config.go +++ b/ddl/ingest/config.go @@ -16,6 +16,7 @@ package ingest import ( "path/filepath" + "sync/atomic" "github.com/pingcap/tidb/br/pkg/lightning/backend" "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" @@ -26,12 +27,18 @@ import ( "go.uber.org/zap" ) +// ImporterRangeConcurrencyForTest is only used for test. +var ImporterRangeConcurrencyForTest *atomic.Int32 + func generateLightningConfig(memRoot MemRoot, jobID int64, unique bool) (*config.Config, error) { tidbCfg := tidbconf.GetGlobalConfig() cfg := config.NewConfig() cfg.TikvImporter.Backend = config.BackendLocal // Each backend will build a single dir in lightning dir. cfg.TikvImporter.SortedKVDir = filepath.Join(LitSortPath, encodeBackendTag(jobID)) + if ImporterRangeConcurrencyForTest != nil { + cfg.TikvImporter.RangeConcurrency = int(ImporterRangeConcurrencyForTest.Load()) + } _, err := cfg.AdjustCommon() if err != nil { logutil.BgLogger().Warn(LitWarnConfigError, zap.Error(err)) diff --git a/errors.toml b/errors.toml index 31a56ef6b1d17..a50e484a5a29b 100644 --- a/errors.toml +++ b/errors.toml @@ -1451,6 +1451,11 @@ error = ''' SET PASSWORD has no significance for user '%-.48s'@'%-.255s' as authentication plugin does not support it. ''' +["executor:1819"] +error = ''' +Your password does not satisfy the current policy requirements +''' + ["executor:1827"] error = ''' The password hash doesn't have the expected format. Check if the correct password algorithm is being used with the PASSWORD() function. diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 6a300dbeaf654..cf91360b17a60 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -177,6 +177,7 @@ go_library( "//util/mathutil", "//util/memory", "//util/mvmap", + "//util/password-validation", "//util/pdapi", "//util/plancodec", "//util/printer", diff --git a/executor/analyze.go b/executor/analyze.go index da74d8248a90c..f08f1ad932a9c 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -325,8 +325,9 @@ func (e *AnalyzeExec) handleResultsError(ctx context.Context, concurrency int, n handleGlobalStats(needGlobalStats, globalStatsMap, results) if err1 := statsHandle.SaveTableStatsToStorage(results, e.ctx.GetSessionVars().EnableAnalyzeSnapshot); err1 != nil { + tableID := results.TableID.TableID err = err1 - logutil.Logger(ctx).Error("save table stats to storage failed", zap.Error(err)) + logutil.Logger(ctx).Error("save table stats to storage failed", zap.Error(err), zap.Int64("tableID", tableID)) finishJobWithLog(e.ctx, results.Job, err) } else { finishJobWithLog(e.ctx, results.Job, nil) diff --git a/executor/analyze_col_v2.go b/executor/analyze_col_v2.go index 0ab8a3a019cbe..879d47e85a88e 100644 --- a/executor/analyze_col_v2.go +++ b/executor/analyze_col_v2.go @@ -200,6 +200,27 @@ func (e *AnalyzeColumnsExecV2) decodeSampleDataWithVirtualColumn( return nil } +func printAnalyzeMergeCollectorLog(oldRootCount, newRootCount, subCount, tableID, partitionID int64, isPartition bool, info string, index int) { + if index < 0 { + logutil.BgLogger().Debug(info, + zap.Int64("tableID", tableID), + zap.Int64("partitionID", partitionID), + zap.Bool("isPartitionTable", isPartition), + zap.Int64("oldRootCount", oldRootCount), + zap.Int64("newRootCount", newRootCount), + zap.Int64("subCount", subCount)) + } else { + logutil.BgLogger().Debug(info, + zap.Int64("tableID", tableID), + zap.Int64("partitionID", partitionID), + zap.Bool("isPartitionTable", isPartition), + zap.Int64("oldRootCount", oldRootCount), + zap.Int64("newRootCount", newRootCount), + zap.Int64("subCount", subCount), + zap.Int("subCollectorIndex", index)) + } +} + func (e *AnalyzeColumnsExecV2) buildSamplingStats( ranges []*ranger.Range, needExtStats bool, @@ -236,7 +257,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats( e.samplingMergeWg = &util.WaitGroupWrapper{} e.samplingMergeWg.Add(statsConcurrency) for i := 0; i < statsConcurrency; i++ { - go e.subMergeWorker(mergeResultCh, mergeTaskCh, l, i == 0) + go e.subMergeWorker(mergeResultCh, mergeTaskCh, l, i) } if err = readDataAndSendTask(e.ctx, e.resultHandler, mergeTaskCh, e.memTracker); err != nil { return 0, nil, nil, nil, nil, getAnalyzePanicErr(err) @@ -256,7 +277,12 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats( continue } oldRootCollectorSize := rootRowCollector.Base().MemSize + oldRootCollectorCount := rootRowCollector.Base().Count rootRowCollector.MergeCollector(mergeResult.collector) + newRootCollectorCount := rootRowCollector.Base().Count + printAnalyzeMergeCollectorLog(oldRootCollectorCount, newRootCollectorCount, + mergeResult.collector.Base().Count, e.tableID.TableID, e.tableID.PartitionID, e.tableID.IsPartitionTable(), + "merge subMergeWorker in AnalyzeColumnsExecV2", -1) e.memTracker.Consume(rootRowCollector.Base().MemSize - oldRootCollectorSize - mergeResult.collector.Base().MemSize) } defer e.memTracker.Release(rootRowCollector.Base().MemSize) @@ -545,7 +571,8 @@ func (e *AnalyzeColumnsExecV2) buildSubIndexJobForSpecialIndex(indexInfos []*mod return tasks } -func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResult, taskCh <-chan []byte, l int, isClosedChanThread bool) { +func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResult, taskCh <-chan []byte, l int, index int) { + isClosedChanThread := index == 0 defer func() { if r := recover(); r != nil { logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack")) @@ -590,7 +617,12 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu subCollector.Base().FromProto(colResp.RowCollector, e.memTracker) UpdateAnalyzeJob(e.ctx, e.job, subCollector.Base().Count) oldRetCollectorSize := retCollector.Base().MemSize + oldRetCollectorCount := retCollector.Base().Count retCollector.MergeCollector(subCollector) + newRetCollectorCount := retCollector.Base().Count + printAnalyzeMergeCollectorLog(oldRetCollectorCount, newRetCollectorCount, subCollector.Base().Count, + e.tableID.TableID, e.tableID.PartitionID, e.TableID.IsPartitionTable(), + "merge subCollector in concurrency in AnalyzeColumnsExecV2", index) newRetCollectorSize := retCollector.Base().MemSize subCollectorSize := subCollector.Base().MemSize e.memTracker.Consume(newRetCollectorSize - oldRetCollectorSize - subCollectorSize) diff --git a/executor/analyze_global_stats.go b/executor/analyze_global_stats.go index 961d41dea059d..46e9fdbf41544 100644 --- a/executor/analyze_global_stats.go +++ b/executor/analyze_global_stats.go @@ -73,7 +73,8 @@ func (e *AnalyzeExec) handleGlobalStats(ctx context.Context, needGlobalStats boo globalStatsID.tableID, info.isIndex, info.histIDs, tableAllPartitionStats) if err != nil { - logutil.BgLogger().Error("merge global stats failed", zap.String("info", job.JobInfo), zap.Error(err)) + logutil.BgLogger().Error("merge global stats failed", + zap.String("info", job.JobInfo), zap.Error(err), zap.Int64("tableID", tableID)) if types.ErrPartitionStatsMissing.Equal(err) || types.ErrPartitionColumnStatsMissing.Equal(err) { // When we find some partition-level stats are missing, we need to report warning. e.ctx.GetSessionVars().StmtCtx.AppendWarning(err) @@ -95,7 +96,8 @@ func (e *AnalyzeExec) handleGlobalStats(ctx context.Context, needGlobalStats boo true, ) if err != nil { - logutil.Logger(ctx).Error("save global-level stats to storage failed", zap.String("info", job.JobInfo), zap.Int64("histID", hg.ID), zap.Error(err)) + logutil.Logger(ctx).Error("save global-level stats to storage failed", zap.String("info", job.JobInfo), + zap.Int64("histID", hg.ID), zap.Error(err), zap.Int64("tableID", tableID)) } // Dump stats to historical storage. if err1 := recordHistoricalStats(e.ctx, globalStatsID.tableID); err1 != nil { diff --git a/executor/set_test.go b/executor/set_test.go index 697209d64836a..a4a54a37a3595 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -853,6 +853,14 @@ func TestSetVar(t *testing.T) { tk.MustQuery("select @@global.tidb_opt_range_max_size").Check(testkit.Rows("1048576")) tk.MustExec("set session tidb_opt_range_max_size = 2097152") tk.MustQuery("select @@session.tidb_opt_range_max_size").Check(testkit.Rows("2097152")) + + // test for password validation + tk.MustQuery("SELECT @@GLOBAL.validate_password.enable").Check(testkit.Rows("0")) + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("8")) + tk.MustExec("SET GLOBAL validate_password.length = 3") + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("4")) + tk.MustExec("SET GLOBAL validate_password.mixed_case_count = 2") + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("6")) } func TestGetSetNoopVars(t *testing.T) { @@ -1407,14 +1415,11 @@ func TestValidateSetVar(t *testing.T) { tk.MustExec("set @@innodb_lock_wait_timeout = 1073741825") tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect innodb_lock_wait_timeout value: '1073741825'")) - tk.MustExec("set @@global.validate_password_number_count=-1") - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password_number_count value: '-1'")) - - tk.MustExec("set @@global.validate_password_length=-1") - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password_length value: '-1'")) + tk.MustExec("set @@global.validate_password.number_count=-1") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password.number_count value: '-1'")) - tk.MustExec("set @@global.validate_password_length=8") - tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustExec("set @@global.validate_password.length=-1") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password.length value: '-1'")) err = tk.ExecToErr("set @@tx_isolation=''") require.True(t, terror.ErrorEqual(err, variable.ErrWrongValueForVar), fmt.Sprintf("err %v", err)) diff --git a/executor/simple.go b/executor/simple.go index 3670670977a20..5953a3fa687f2 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -51,6 +51,7 @@ import ( "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" + pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tidb/util/sem" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/timeutil" @@ -783,6 +784,23 @@ func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error { return nil } +func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption, authPlugin string) bool { + if authOpt == nil || !authOpt.ByAuthString { + return false + } + return authPlugin == mysql.AuthNativePassword || + authPlugin == mysql.AuthTiDBSM3Password || + authPlugin == mysql.AuthCachingSha2Password +} + +func (e *SimpleExec) isValidatePasswordEnabled() bool { + validatePwdEnable, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordEnable) + if err != nil { + return false + } + return variable.TiDBOptOn(validatePwdEnable) +} + func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStmt) error { internalCtx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) // Check `CREATE USER` privilege. @@ -874,15 +892,25 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm e.ctx.GetSessionVars().StmtCtx.AppendNote(err) continue } + authPlugin := mysql.AuthNativePassword + if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { + authPlugin = spec.AuthOpt.AuthPlugin + } + if e.isValidatePasswordEnabled() && !s.IsCreateRole { + if spec.AuthOpt == nil || !spec.AuthOpt.ByAuthString && spec.AuthOpt.HashString == "" { + return variable.ErrNotValidPassword.GenWithStackByArgs() + } + if e.authUsingCleartextPwd(spec.AuthOpt, authPlugin) { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + return err + } + } + } pwd, ok := spec.EncodedPassword() if !ok { return errors.Trace(ErrPasswordFormat) } - authPlugin := mysql.AuthNativePassword - if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { - authPlugin = spec.AuthOpt.AuthPlugin - } switch authPlugin { case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, mysql.AuthTiDBAuthToken: @@ -1071,11 +1099,11 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) var fields []alterField if spec.AuthOpt != nil { if spec.AuthOpt.AuthPlugin == "" { - authplugin, err := e.userAuthPlugin(spec.User.Username, spec.User.Hostname) + curAuthplugin, err := e.userAuthPlugin(spec.User.Username, spec.User.Hostname) if err != nil { return err } - spec.AuthOpt.AuthPlugin = authplugin + spec.AuthOpt.AuthPlugin = curAuthplugin } switch spec.AuthOpt.AuthPlugin { case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, "": @@ -1087,6 +1115,11 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) default: return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } + if e.isValidatePasswordEnabled() && e.authUsingCleartextPwd(spec.AuthOpt, spec.AuthOpt.AuthPlugin) { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + return err + } + } pwd, ok := spec.EncodedPassword() if !ok { return errors.Trace(ErrPasswordFormat) @@ -1603,6 +1636,11 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error if err != nil { return err } + if e.isValidatePasswordEnabled() { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { + return err + } + } var pwd string switch authplugin { case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password: diff --git a/executor/simple_test.go b/executor/simple_test.go index 13a439ad64d46..44c7691d805f8 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -129,3 +129,86 @@ func TestUserAttributes(t *testing.T) { rootTK.MustExec("alter user usr1 comment 'comment1'") rootTK.MustQuery("select user_attributes from mysql.user where user = 'usr1'").Check(testkit.Rows(`{"metadata": {"comment": "comment1"}}`)) } + +func TestValidatePassword(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + subtk := testkit.NewTestKit(t, store) + err := tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil) + require.NoError(t, err) + tk.MustExec("CREATE USER ''@'localhost'") + tk.MustExec("GRANT ALL PRIVILEGES ON mysql.* TO ''@'localhost';") + err = subtk.Session().Auth(&auth.UserIdentity{Hostname: "localhost"}, nil, nil) + require.NoError(t, err) + + authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) + + for _, authPlugin := range authPlugins { + tk.MustExec("DROP USER IF EXISTS testuser") + tk.MustExec(fmt.Sprintf("CREATE USER testuser IDENTIFIED WITH %s BY '!Abc12345678'", authPlugin)) + + tk.MustExec("SET GLOBAL validate_password.policy = 'LOW'") + // check user name + tk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'", "Password Contains User Name") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'", "Password Contains Reversed User Name") + tk.MustExec("SET PASSWORD FOR 'testuser' = 'testuser'") // password the same as the user name, but run by root + tk.MustExec("ALTER USER testuser IDENTIFIED BY 'testuser'") + tk.MustExec("SET GLOBAL validate_password.check_user_name = 0") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'") + tk.MustExec("SET GLOBAL validate_password.check_user_name = 1") + + // LOW: Length + tk.MustExec("SET GLOBAL validate_password.length = 8") + tk.MustQuery("SELECT @@global.validate_password.length").Check(testkit.Rows("8")) + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length: 8") + tk.MustExec("SET GLOBAL validate_password.length = 12") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdefg123'", "Require Password Length: 12") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdefg1234'") + tk.MustExec("SET GLOBAL validate_password.length = 8") + + // MEDIUM: Length; numeric, lowercase/uppercase, and special characters + tk.MustExec("SET GLOBAL validate_password.policy = 'MEDIUM'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABC1234567'", "Require Password Lowercase Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!abc1234567'", "Require Password Uppercase Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABCDabcd'", "Require Password Digit Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY 'Abc1234567'", "Require Password Non-alphanumeric Count: 1") + tk.MustExec("SET GLOBAL validate_password.special_char_count = 0") + tk.MustExec("ALTER USER testuser IDENTIFIED BY 'Abc1234567'") + tk.MustExec("SET GLOBAL validate_password.special_char_count = 1") + tk.MustExec("SET GLOBAL validate_password.length = 3") + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("4")) + + // STRONG: Length; numeric, lowercase/uppercase, and special characters; dictionary file + tk.MustExec("SET GLOBAL validate_password.policy = 'STRONG'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary = '%s'", "1234;5678")) + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc123567'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc43218765'") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abc1234567'", "Password contains word in the dictionary") + tk.MustExec("SET GLOBAL validate_password.dictionary = ''") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + + // "IDENTIFIED AS 'xxx'" is not affected by validation + tk.MustExec(fmt.Sprintf("ALTER USER testuser IDENTIFIED WITH '%s' AS ''", authPlugin)) + } + tk.MustContainErrMsg("CREATE USER 'testuser1'@'localhost'", "Your password does not satisfy the current policy requirements") + tk.MustContainErrMsg("CREATE USER 'testuser1'@'localhost' IDENTIFIED WITH 'caching_sha2_password'", "Your password does not satisfy the current policy requirements") + tk.MustContainErrMsg("CREATE USER 'testuser1'@'localhost' IDENTIFIED WITH 'caching_sha2_password' AS ''", "Your password does not satisfy the current policy requirements") + + // if the username is '', all password can pass the check_user_name + subtk.MustQuery("SELECT user(), current_user()").Check(testkit.Rows("@localhost @localhost")) + subtk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) + subtk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) + subtk.MustExec("ALTER USER ''@'localhost' IDENTIFIED BY ''") + subtk.MustExec("ALTER USER ''@'localhost' IDENTIFIED BY 'abcd'") + + // CREATE ROLE is not affected by password validation + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustExec("CREATE ROLE role1") +} diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index 032c44054dba2..fc1752ef19e63 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -97,6 +97,7 @@ go_library( "//util/mathutil", "//util/mock", "//util/parser", + "//util/password-validation", "//util/plancodec", "//util/printer", "//util/sem", diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index a206a9d4970bb..fb451f9714cd4 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" + pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tipb/go-tipb" ) @@ -73,6 +74,7 @@ var ( _ builtinFunc = &builtinSHA2Sig{} _ builtinFunc = &builtinUncompressSig{} _ builtinFunc = &builtinUncompressedLengthSig{} + _ builtinFunc = &builtinValidatePasswordStrengthSig{} ) // aesModeAttr indicates that the key length and iv attribute for specific block_encryption_mode. @@ -728,7 +730,6 @@ func (c *sm3FunctionClass) getFunction(ctx sessionctx.Context, args []Expression bf.tp.SetCollate(collate) bf.tp.SetFlen(40) sig := &builtinSM3Sig{bf} - //sig.setPbCode(tipb.ScalarFuncSig_SM3) // TODO return sig, nil } @@ -1010,5 +1011,66 @@ type validatePasswordStrengthFunctionClass struct { } func (c *validatePasswordStrengthFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "VALIDATE_PASSWORD_STRENGTH") + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString) + if err != nil { + return nil, err + } + bf.tp.SetFlen(21) + sig := &builtinValidatePasswordStrengthSig{bf} + return sig, nil +} + +type builtinValidatePasswordStrengthSig struct { + baseBuiltinFunc +} + +func (b *builtinValidatePasswordStrengthSig) Clone() builtinFunc { + newSig := &builtinValidatePasswordStrengthSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals VALIDATE_PASSWORD_STRENGTH(str). +// See https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_validate-password-strength +func (b *builtinValidatePasswordStrengthSig) evalInt(row chunk.Row) (int64, bool, error) { + globalVars := b.ctx.GetSessionVars().GlobalVarsAccessor + str, isNull, err := b.args[0].EvalString(b.ctx, row) + if err != nil || isNull { + return 0, true, err + } else if len([]rune(str)) < 4 { + return 0, false, nil + } + if validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { + return 0, true, err + } else if !variable.TiDBOptOn(validation) { + return 0, false, nil + } + return b.validateStr(str, &globalVars) +} + +func (b *builtinValidatePasswordStrengthSig) validateStr(str string, globalVars *variable.GlobalVarAccessor) (int64, bool, error) { + if warn, err := pwdValidator.ValidateUserNameInPassword(str, b.ctx.GetSessionVars()); err != nil { + return 0, true, err + } else if len(warn) > 0 { + return 0, false, nil + } + if warn, err := pwdValidator.ValidatePasswordLowPolicy(str, globalVars); err != nil { + return 0, true, err + } else if len(warn) > 0 { + return 25, false, nil + } + if warn, err := pwdValidator.ValidatePasswordMediumPolicy(str, globalVars); err != nil { + return 0, true, err + } else if len(warn) > 0 { + return 50, false, nil + } + if ok, err := pwdValidator.ValidateDictionaryPassword(str, globalVars); err != nil { + return 0, true, err + } else if !ok { + return 75, false, nil + } + return 100, false, nil } diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 0f74ab611aa48..087fb3f35e466 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -15,12 +15,14 @@ package expression import ( + "context" "encoding/hex" "fmt" "strings" "testing" "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -631,6 +633,55 @@ func TestUncompressLength(t *testing.T) { } } +func TestValidatePasswordStrength(t *testing.T) { + ctx := createContext(t) + ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testuser"} + globalVarsAccessor := variable.NewMockGlobalAccessor4Tests() + ctx.GetSessionVars().GlobalVarsAccessor = globalVarsAccessor + err := globalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordDictionary, "1234") + require.NoError(t, err) + + tests := []struct { + in interface{} + expect interface{} + }{ + {nil, nil}, + {"123", 0}, + {"testuser123", 0}, + {"resutset123", 0}, + {"12345", 25}, + {"12345678", 50}, + {"!Abc12345678", 75}, + {"!Abc87654321", 100}, + } + + fc := funcs[ast.ValidatePasswordStrength] + // disable password validation + for _, test := range tests { + arg := types.NewDatum(test.in) + f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{arg})) + require.NoErrorf(t, err, "%v", test) + out, err := evalBuiltinFunc(f, chunk.Row{}) + require.NoErrorf(t, err, "%v", test) + if test.expect == nil { + require.Equal(t, types.NewDatum(nil), out) + } else { + require.Equalf(t, types.NewDatum(0), out, "%v", test) + } + } + // enable password validation + err = globalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordEnable, "ON") + require.NoError(t, err) + for _, test := range tests { + arg := types.NewDatum(test.in) + f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{arg})) + require.NoErrorf(t, err, "%v", test) + out, err := evalBuiltinFunc(f, chunk.Row{}) + require.NoErrorf(t, err, "%v", test) + require.Equalf(t, types.NewDatum(test.expect), out, "%v", test) + } +} + func TestPassword(t *testing.T) { ctx := createContext(t) cases := []struct { diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index e9a1d45ae67be..ff71913f8d70b 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" @@ -863,3 +864,45 @@ func (b *builtinUncompressedLengthSig) vecEvalInt(input *chunk.Chunk, result *ch } return nil } + +func (b *builtinValidatePasswordStrengthSig) vectorized() bool { + return true +} + +func (b *builtinValidatePasswordStrengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { + return err + } + + result.ResizeInt64(n, false) + result.MergeNulls(buf) + i64s := result.Int64s() + globalVars := b.ctx.GetSessionVars().GlobalVarsAccessor + enableValidation := false + validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable) + if err != nil { + return err + } + enableValidation = variable.TiDBOptOn(validation) + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + if !enableValidation { + i64s[i] = 0 + } else if score, isNull, err := b.validateStr(buf.GetString(i), &globalVars); err != nil { + return err + } else if !isNull { + i64s[i] = score + } else { + result.SetNull(i, true) + } + } + return nil +} diff --git a/expression/builtin_encryption_vec_test.go b/expression/builtin_encryption_vec_test.go index c6caa1eb60d51..46395e51bcb6b 100644 --- a/expression/builtin_encryption_vec_test.go +++ b/expression/builtin_encryption_vec_test.go @@ -75,6 +75,9 @@ var vecBuiltinEncryptionCases = map[string][]vecExprBenchCase{ ast.Decode: { {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}, geners: []dataGenerator{newRandLenStrGener(10, 20)}}, }, + ast.ValidatePasswordStrength: { + {retEvalType: types.ETInt, childrenTypes: []types.EvalType{types.ETString}}, + }, } func TestVectorizedBuiltinEncryptionFunc(t *testing.T) { diff --git a/expression/integration_test.go b/expression/integration_test.go index a0e2f93b3103f..bb5bddfa7d9a4 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -968,6 +968,7 @@ func TestEncryptionBuiltin(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test") + tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "root"} ctx := context.Background() // for password @@ -1143,6 +1144,25 @@ func TestEncryptionBuiltin(t *testing.T) { tk.MustQuery("SELECT RANDOM_BYTES(1024);") result = tk.MustQuery("SELECT RANDOM_BYTES(NULL);") result.Check(testkit.Rows("")) + + // for VALIDATE_PASSWORD_STRENGTH + tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary='%s'", "password")) + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustQuery("SELECT validate_password_strength('root')").Check(testkit.Rows("0")) + tk.MustQuery("SELECT validate_password_strength('toor')").Check(testkit.Rows("0")) + tk.MustQuery("SELECT validate_password_strength('ROOT')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT validate_password_strength('TOOR')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT validate_password_strength('fooHoHo%1')").Check(testkit.Rows("100")) + tk.MustQuery("SELECT validate_password_strength('pass')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT validate_password_strength('password')").Check(testkit.Rows("50")) + tk.MustQuery("SELECT validate_password_strength('password0000')").Check(testkit.Rows("50")) + tk.MustQuery("SELECT validate_password_strength('password1A#')").Check(testkit.Rows("75")) + tk.MustQuery("SELECT validate_password_strength('PA12wrd!#')").Check(testkit.Rows("100")) + tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH(REPEAT(\"aA1#\", 26))").Check(testkit.Rows("100")) + tk.MustQuery("SELECT validate_password_strength(null)").Check(testkit.Rows("")) + tk.MustQuery("SELECT validate_password_strength('null')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH( 0x6E616E646F73617135234552 )").Check(testkit.Rows("100")) + tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH(CAST(0xd2 AS BINARY(10)))").Check(testkit.Rows("50")) } func TestOpBuiltin(t *testing.T) { diff --git a/go.mod b/go.mod index a0725a1599c30..2cd460e8279c3 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/coreos/go-semver v0.3.0 github.com/daixiang0/gci v0.8.5 github.com/danjacques/gofslock v0.0.0-20191023191349-0a45f885bc37 - github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d + github.com/dgraph-io/ristretto v0.1.1 github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 github.com/docker/go-units v0.4.0 github.com/emirpasic/gods v1.18.1 @@ -110,7 +110,7 @@ require ( golang.org/x/net v0.1.0 golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5 golang.org/x/sync v0.1.0 - golang.org/x/sys v0.1.0 + golang.org/x/sys v0.2.0 golang.org/x/term v0.1.0 golang.org/x/text v0.4.0 golang.org/x/time v0.1.0 diff --git a/go.sum b/go.sum index e3f2b0b490ae7..bdc73a0d74ddc 100644 --- a/go.sum +++ b/go.sum @@ -231,8 +231,8 @@ github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 h1:HbphB4TFFXpv7MNrT52FGrrgVXF1owhMVTHFZIlnvd4= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0/go.mod h1:DZGJHZMqrU4JJqFAWUS2UO1+lbSKsdiOoYi9Zzey7Fc= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= -github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d h1:Wrc3UKTS+cffkOx0xRGFC+ZesNuTfn0ThvEC72N0krk= -github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d/go.mod h1:RAy2GVV4sTWVlNMavv3xhLsk18rxhfhDnombTe6EF5c= +github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= +github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190104051053-3adb47b1fb0f/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= @@ -1311,8 +1311,9 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220909162455-aba9fc2a8ff2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= diff --git a/sessionctx/variable/error.go b/sessionctx/variable/error.go index 60928932f0f06..f760cba8bfcd5 100644 --- a/sessionctx/variable/error.go +++ b/sessionctx/variable/error.go @@ -39,6 +39,7 @@ var ( errLocalVariable = dbterror.ClassVariable.NewStd(mysql.ErrLocalVariable) errValueNotSupportedWhen = dbterror.ClassVariable.NewStdErr(mysql.ErrNotSupportedYet, pmysql.Message("%s = OFF is not supported when %s = ON", nil)) ErrStmtNotFound = dbterror.ClassOptimizer.NewStd(mysql.ErrPreparedStmtNotFound) + ErrNotValidPassword = dbterror.ClassExecutor.NewStd(mysql.ErrNotValidPassword) // ErrFunctionsNoopImpl is an error to say the behavior is protected by the tidb_enable_noop_functions sysvar. // This is copied from expression.ErrFunctionsNoopImpl to prevent circular dependencies. // It needs to be public for tests. diff --git a/sessionctx/variable/noop.go b/sessionctx/variable/noop.go index 398ea09f3ec92..5505fe65a3623 100644 --- a/sessionctx/variable/noop.go +++ b/sessionctx/variable/noop.go @@ -58,8 +58,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal | ScopeSession, Name: BigTables, Value: Off, Type: TypeBool}, {Scope: ScopeNone, Name: "skip_external_locking", Value: "1"}, {Scope: ScopeNone, Name: "innodb_sync_array_size", Value: "1"}, - {Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: Off, Type: TypeBool}, - {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64}, {Scope: ScopeSession, Name: "gtid_next", Value: ""}, {Scope: ScopeGlobal, Name: "ndb_show_foreign_key_mock_tables", Value: ""}, {Scope: ScopeNone, Name: "multi_range_count", Value: "256"}, @@ -463,7 +461,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal | ScopeSession, Name: "eq_range_index_dive_limit", Value: "200", IsHintUpdatable: true}, {Scope: ScopeNone, Name: "performance_schema_events_stages_history_size", Value: "10"}, {Scope: ScopeGlobal | ScopeSession, Name: "ndb_join_pushdown", Value: ""}, - {Scope: ScopeGlobal, Name: "validate_password_special_char_count", Value: "1"}, {Scope: ScopeNone, Name: "performance_schema_max_thread_instances", Value: "402"}, {Scope: ScopeGlobal | ScopeSession, Name: "ndbinfo_show_hidden", Value: ""}, {Scope: ScopeGlobal | ScopeSession, Name: "net_read_timeout", Value: "30"}, @@ -472,7 +469,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: "sync_relay_log_info", Value: "10000"}, {Scope: ScopeGlobal | ScopeSession, Name: "optimizer_trace_limit", Value: "1"}, {Scope: ScopeNone, Name: "innodb_ft_max_token_size", Value: "84"}, - {Scope: ScopeGlobal, Name: ValidatePasswordLength, Value: "8", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64}, {Scope: ScopeGlobal, Name: "ndb_log_binlog_index", Value: ""}, {Scope: ScopeGlobal, Name: "innodb_api_bk_commit_interval", Value: "5"}, {Scope: ScopeNone, Name: "innodb_undo_directory", Value: "."}, diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index b3902d8f0e431..060c542bddd77 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -486,6 +486,86 @@ var defaultSysVars = []*SysVar{ } return normalizedValue, nil }}, + {Scope: ScopeGlobal, Name: ValidatePasswordEnable, Value: Off, Type: TypeBool}, + {Scope: ScopeGlobal, Name: ValidatePasswordPolicy, Value: "MEDIUM", Type: TypeEnum, PossibleValues: []string{"LOW", "MEDIUM", "STRONG"}}, + {Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: On, Type: TypeBool}, + {Scope: ScopeGlobal, Name: ValidatePasswordLength, Value: "8", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + numberCount, specialCharCount, mixedCaseCount := PasswordValidtaionNumberCount.Load(), PasswordValidationSpecialCharCount.Load(), PasswordValidationMixedCaseCount.Load() + length, err := strconv.ParseInt(normalizedValue, 10, 32) + if err != nil { + return "", err + } + if minLength := numberCount + specialCharCount + 2*mixedCaseCount; int32(length) < minLength { + return strconv.FormatInt(int64(minLength), 10), nil + } + return normalizedValue, nil + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidationLength.Store(int32(TidbOptInt64(val, 8))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return strconv.FormatInt(int64(PasswordValidationLength.Load()), 10), nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + length, numberCount, specialCharCount := PasswordValidationLength.Load(), PasswordValidtaionNumberCount.Load(), PasswordValidationSpecialCharCount.Load() + mixedCaseCount, err := strconv.ParseInt(normalizedValue, 10, 32) + if err != nil { + return "", err + } + if minLength := numberCount + specialCharCount + 2*int32(mixedCaseCount); length < minLength { + PasswordValidationLength.Store(minLength) + } + return normalizedValue, nil + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidationMixedCaseCount.Store(int32(TidbOptInt64(val, 1))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return strconv.FormatInt(int64(PasswordValidationMixedCaseCount.Load()), 10), nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + length, specialCharCount, mixedCaseCount := PasswordValidationLength.Load(), PasswordValidationSpecialCharCount.Load(), PasswordValidationMixedCaseCount.Load() + numberCount, err := strconv.ParseInt(normalizedValue, 10, 32) + if err != nil { + return "", err + } + if minLength := int32(numberCount) + specialCharCount + 2*mixedCaseCount; length < minLength { + PasswordValidationLength.Store(minLength) + } + return normalizedValue, nil + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidtaionNumberCount.Store(int32(TidbOptInt64(val, 1))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return strconv.FormatInt(int64(PasswordValidtaionNumberCount.Load()), 10), nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + length, numberCount, mixedCaseCount := PasswordValidationLength.Load(), PasswordValidtaionNumberCount.Load(), PasswordValidationMixedCaseCount.Load() + specialCharCount, err := strconv.ParseInt(normalizedValue, 10, 32) + if err != nil { + return "", err + } + if minLength := numberCount + int32(specialCharCount) + 2*mixedCaseCount; length < minLength { + PasswordValidationLength.Store(minLength) + } + return normalizedValue, nil + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidationSpecialCharCount.Store(int32(TidbOptInt64(val, 1))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return strconv.FormatInt(int64(PasswordValidationSpecialCharCount.Load()), 10), nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordDictionary, Value: "", Type: TypeStr}, /* TiDB specific variables */ {Scope: ScopeGlobal, Name: TiDBTSOClientBatchMaxWaitTime, Value: strconv.FormatFloat(DefTiDBTSOClientBatchMaxWaitTime, 'f', -1, 64), Type: TypeFloat, MinValue: 0, MaxValue: 10, @@ -1047,7 +1127,7 @@ var defaultSysVars = []*SysVar{ MemoryUsageAlarmKeepRecordNum.Store(TidbOptInt64(val, DefMemoryUsageAlarmKeepRecordNum)) return nil }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { - return fmt.Sprintf("%d", MemoryUsageAlarmKeepRecordNum.Load()), nil + return strconv.FormatInt(MemoryUsageAlarmKeepRecordNum.Load(), 10), nil }}, /* The system variables below have GLOBAL and SESSION scope */ @@ -2125,10 +2205,6 @@ const ( BlockEncryptionMode = "block_encryption_mode" // WaitTimeout is the name for 'wait_timeout' system variable. WaitTimeout = "wait_timeout" - // ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable. - ValidatePasswordNumberCount = "validate_password_number_count" - // ValidatePasswordLength is the name of 'validate_password_length' system variable. - ValidatePasswordLength = "validate_password_length" // Version is the name of 'version' system variable. Version = "version" // VersionComment is the name of 'version_comment' system variable. @@ -2151,8 +2227,6 @@ const ( BinlogOrderCommits = "binlog_order_commits" // MasterVerifyChecksum is the name for 'master_verify_checksum' system variable. MasterVerifyChecksum = "master_verify_checksum" - // ValidatePasswordCheckUserName is the name for 'validate_password_check_user_name' system variable. - ValidatePasswordCheckUserName = "validate_password_check_user_name" // SuperReadOnly is the name for 'super_read_only' system variable. SuperReadOnly = "super_read_only" // SQLNotes is the name for 'sql_notes' system variable. @@ -2319,4 +2393,21 @@ const ( RandSeed2 = "rand_seed2" // SQLRequirePrimaryKey is the name of `sql_require_primary_key` system variable. SQLRequirePrimaryKey = "sql_require_primary_key" + // ValidatePasswordEnable turns on/off the validation of password. + ValidatePasswordEnable = "validate_password.enable" + // ValidatePasswordPolicy specifies the password policy enforced by validate_password. + ValidatePasswordPolicy = "validate_password.policy" + // ValidatePasswordCheckUserName controls whether validate_password compares passwords to the user name part of + // the effective user account for the current session + ValidatePasswordCheckUserName = "validate_password.check_user_name" + // ValidatePasswordLength specified the minimum number of characters that validate_password requires passwords to have + ValidatePasswordLength = "validate_password.length" + // ValidatePasswordMixedCaseCount specified the minimum number of lowercase and uppercase characters that validate_password requires + ValidatePasswordMixedCaseCount = "validate_password.mixed_case_count" + // ValidatePasswordNumberCount specified the minimum number of numeric (digit) characters that validate_password requires + ValidatePasswordNumberCount = "validate_password.number_count" + // ValidatePasswordSpecialCharCount specified the minimum number of nonalphanumeric characters that validate_password requires + ValidatePasswordSpecialCharCount = "validate_password.special_char_count" + // ValidatePasswordDictionary specified the dictionary that validate_password uses for checking passwords. Each word is separated by semicolon (;). + ValidatePasswordDictionary = "validate_password.dictionary" ) diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index e4fa0b671cebe..a9e278107d270 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -1150,6 +1150,11 @@ var ( // It should be a const and shouldn't be modified after tidb is started. DefTiDBServerMemoryLimit = serverMemoryLimitDefaultValue() GOGCTunerThreshold = atomic.NewFloat64(DefTiDBGOGCTunerThreshold) + + PasswordValidationLength = atomic.NewInt32(8) + PasswordValidationMixedCaseCount = atomic.NewInt32(1) + PasswordValidtaionNumberCount = atomic.NewInt32(1) + PasswordValidationSpecialCharCount = atomic.NewInt32(1) ) var ( diff --git a/tests/realtikvtest/addindextest/integration_test.go b/tests/realtikvtest/addindextest/integration_test.go index 7427f935c78ca..70da49e58364b 100644 --- a/tests/realtikvtest/addindextest/integration_test.go +++ b/tests/realtikvtest/addindextest/integration_test.go @@ -187,6 +187,8 @@ func TestAddIndexIngestAdjustBackfillWorkerCountFail(t *testing.T) { tk.MustExec("create database addindexlit;") tk.MustExec("use addindexlit;") tk.MustExec(`set global tidb_ddl_enable_fast_reorg=on;`) + ingest.ImporterRangeConcurrencyForTest = &atomic.Int32{} + ingest.ImporterRangeConcurrencyForTest.Store(2) tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 20;") tk.MustExec("create table t (a int primary key);") var sb strings.Builder @@ -205,4 +207,52 @@ func TestAddIndexIngestAdjustBackfillWorkerCountFail(t *testing.T) { jobTp := rows[0][3].(string) require.True(t, strings.Contains(jobTp, "ingest"), jobTp) tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 4;") + ingest.ImporterRangeConcurrencyForTest = nil +} + +func TestAddIndexIngestGeneratedColumns(t *testing.T) { + store := realtikvtest.CreateMockStoreAndSetup(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("drop database if exists addindexlit;") + tk.MustExec("create database addindexlit;") + tk.MustExec("use addindexlit;") + tk.MustExec(`set global tidb_ddl_enable_fast_reorg=on;`) + assertLastNDDLUseIngest := func(n int) { + tk.MustExec("admin check table t;") + rows := tk.MustQuery(fmt.Sprintf("admin show ddl jobs %d;", n)).Rows() + require.Len(t, rows, n) + for i := 0; i < n; i++ { + jobTp := rows[i][3].(string) + require.True(t, strings.Contains(jobTp, "ingest"), jobTp) + } + } + tk.MustExec("create table t (a int, b int, c int as (b+10), d int as (b+c), primary key (a) clustered);") + tk.MustExec("insert into t (a, b) values (1, 1), (2, 2), (3, 3);") + tk.MustExec("alter table t add index idx(c);") + tk.MustExec("alter table t add index idx1(c, a);") + tk.MustExec("alter table t add index idx2(a);") + tk.MustExec("alter table t add index idx3(d);") + tk.MustExec("alter table t add index idx4(d, c);") + tk.MustQuery("select * from t;").Check(testkit.Rows("1 1 11 12", "2 2 12 14", "3 3 13 16")) + assertLastNDDLUseIngest(5) + + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a int, b char(10), c char(10) as (concat(b, 'x')), d int, e char(20) as (c));") + tk.MustExec("insert into t (a, b, d) values (1, '1', 1), (2, '2', 2), (3, '3', 3);") + tk.MustExec("alter table t add index idx(c);") + tk.MustExec("alter table t add index idx1(a, c);") + tk.MustExec("alter table t add index idx2(c(7));") + tk.MustExec("alter table t add index idx3(e(5));") + tk.MustQuery("select * from t;").Check(testkit.Rows("1 1 1x 1 1x", "2 2 2x 2 2x", "3 3 3x 3 3x")) + assertLastNDDLUseIngest(4) + + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a int, b char(10), c tinyint, d int as (a + c), e bigint as (d - a), primary key(b, a) clustered);") + tk.MustExec("insert into t (a, b, c) values (1, '1', 1), (2, '2', 2), (3, '3', 3);") + tk.MustExec("alter table t add index idx(d);") + tk.MustExec("alter table t add index idx1(b(2), d);") + tk.MustExec("alter table t add index idx2(d, c);") + tk.MustExec("alter table t add index idx3(e);") + tk.MustQuery("select * from t;").Check(testkit.Rows("1 1 1 2 1", "2 2 2 4 2", "3 3 3 6 3")) + assertLastNDDLUseIngest(4) } diff --git a/util/password-validation/BUILD.bazel b/util/password-validation/BUILD.bazel new file mode 100644 index 0000000000000..c3649a3a15383 --- /dev/null +++ b/util/password-validation/BUILD.bazel @@ -0,0 +1,23 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "password-validation", + srcs = ["password_validation.go"], + importpath = "github.com/pingcap/tidb/util/password-validation", + visibility = ["//visibility:public"], + deps = [ + "//sessionctx/variable", + "//util/hack", + ], +) + +go_test( + name = "password-validation_test", + srcs = ["password_validation_test.go"], + embed = [":password-validation"], + deps = [ + "//parser/auth", + "//sessionctx/variable", + "@com_github_stretchr_testify//require", + ], +) diff --git a/util/password-validation/password_validation.go b/util/password-validation/password_validation.go new file mode 100644 index 0000000000000..edd0bd39ec38a --- /dev/null +++ b/util/password-validation/password_validation.go @@ -0,0 +1,175 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + "bytes" + "fmt" + "strconv" + "strings" + "unicode" + + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/hack" +) + +const maxPwdValidationLength int = 100 + +const minPwdValidationLength int = 4 + +// ValidateDictionaryPassword checks if the password contains words in the dictionary. +func ValidateDictionaryPassword(pwd string, globalVars *variable.GlobalVarAccessor) (bool, error) { + dictionary, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordDictionary) + if err != nil { + return false, err + } + words := strings.Split(dictionary, ";") + if len(words) == 0 { + return true, nil + } + pwd = strings.ToLower(pwd) + for _, word := range words { + if len(word) >= minPwdValidationLength && len(word) <= maxPwdValidationLength { + if strings.Contains(pwd, strings.ToLower(word)) { + return false, nil + } + } + } + return true, nil +} + +// ValidateUserNameInPassword checks whether pwd exists in the dictionary. +func ValidateUserNameInPassword(pwd string, sessionVars *variable.SessionVars) (string, error) { + currentUser := sessionVars.User + globalVars := sessionVars.GlobalVarsAccessor + pwdBytes := hack.Slice(pwd) + if checkUserName, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordCheckUserName); err != nil { + return "", err + } else if currentUser != nil && variable.TiDBOptOn(checkUserName) { + for _, username := range []string{currentUser.AuthUsername, currentUser.Username} { + usernameBytes := hack.Slice(username) + userNameLen := len(usernameBytes) + if userNameLen == 0 { + continue + } + if bytes.Contains(pwdBytes, usernameBytes) { + return "Password Contains User Name", nil + } + usernameReversedBytes := make([]byte, userNameLen) + for i := range usernameBytes { + usernameReversedBytes[i] = usernameBytes[userNameLen-1-i] + } + if bytes.Contains(pwdBytes, usernameReversedBytes) { + return "Password Contains Reversed User Name", nil + } + } + } + return "", nil +} + +// ValidatePasswordLowPolicy checks whether pwd satisfies the low policy of password validation. +func ValidatePasswordLowPolicy(pwd string, globalVars *variable.GlobalVarAccessor) (string, error) { + if validateLengthStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordLength); err != nil { + return "", err + } else if validateLength, err := strconv.ParseInt(validateLengthStr, 10, 64); err != nil { + return "", err + } else if (int64)(len([]rune(pwd))) < validateLength { + return fmt.Sprintf("Require Password Length: %d", validateLength), nil + } + return "", nil +} + +// ValidatePasswordMediumPolicy checks whether pwd satisfies the medium policy of password validation. +func ValidatePasswordMediumPolicy(pwd string, globalVars *variable.GlobalVarAccessor) (string, error) { + var lowerCaseCount, upperCaseCount, numberCount, specialCharCount int64 + runes := []rune(pwd) + for i := 0; i < len(runes); i++ { + if unicode.IsUpper(runes[i]) { + upperCaseCount++ + } else if unicode.IsLower(runes[i]) { + lowerCaseCount++ + } else if unicode.IsDigit(runes[i]) { + numberCount++ + } else { + specialCharCount++ + } + } + if mixedCaseCountStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordMixedCaseCount); err != nil { + return "", err + } else if mixedCaseCount, err := strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { + return "", err + } else if lowerCaseCount < mixedCaseCount { + return fmt.Sprintf("Require Password Lowercase Count: %d", mixedCaseCount), nil + } else if upperCaseCount < mixedCaseCount { + return fmt.Sprintf("Require Password Uppercase Count: %d", mixedCaseCount), nil + } + if requireNumberCountStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordNumberCount); err != nil { + return "", err + } else if requireNumberCount, err := strconv.ParseInt(requireNumberCountStr, 10, 64); err != nil { + return "", err + } else if numberCount < requireNumberCount { + return fmt.Sprintf("Require Password Digit Count: %d", requireNumberCount), nil + } + if requireSpecialCharCountStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordSpecialCharCount); err != nil { + return "", err + } else if requireSpecialCharCount, err := strconv.ParseInt(requireSpecialCharCountStr, 10, 64); err != nil { + return "", err + } else if specialCharCount < requireSpecialCharCount { + return fmt.Sprintf("Require Password Non-alphanumeric Count: %d", requireSpecialCharCount), nil + } + return "", nil +} + +// ValidatePassword checks whether the pwd can be used. +func ValidatePassword(sessionVars *variable.SessionVars, pwd string) error { + globalVars := sessionVars.GlobalVarsAccessor + + validatePolicy, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordPolicy) + if err != nil { + return err + } + if warn, err := ValidateUserNameInPassword(pwd, sessionVars); err != nil { + return err + } else if len(warn) > 0 { + return variable.ErrNotValidPassword.GenWithStack(warn) + } + if warn, err := ValidatePasswordLowPolicy(pwd, &globalVars); err != nil { + return err + } else if len(warn) > 0 { + return variable.ErrNotValidPassword.GenWithStack(warn) + } + // LOW + if validatePolicy == "LOW" { + return nil + } + + // MEDIUM + if warn, err := ValidatePasswordMediumPolicy(pwd, &globalVars); err != nil { + return err + } else if len(warn) > 0 { + return variable.ErrNotValidPassword.GenWithStack(warn) + } + if validatePolicy == "MEDIUM" { + return nil + } + + // STRONG + if ok, err := ValidateDictionaryPassword(pwd, &globalVars); err != nil { + return err + } else if !ok { + return variable.ErrNotValidPassword.GenWithStack("Password contains word in the dictionary") + } + return nil +} diff --git a/util/password-validation/password_validation_test.go b/util/password-validation/password_validation_test.go new file mode 100644 index 0000000000000..323cba33ba409 --- /dev/null +++ b/util/password-validation/password_validation_test.go @@ -0,0 +1,137 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + "context" + "testing" + + "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/stretchr/testify/require" +) + +func TestValidateDictionaryPassword(t *testing.T) { + vars := variable.NewSessionVars(nil) + mock := variable.NewMockGlobalAccessor4Tests() + mock.SessionVars = vars + vars.GlobalVarsAccessor = mock + + err := mock.SetGlobalSysVar(context.Background(), variable.ValidatePasswordDictionary, "1234;5678;HIJK") + require.NoError(t, err) + testcases := []struct { + pwd string + result bool + }{ + {"abcdefg", true}, + {"abcd123efg", true}, + {"abcd1234efg", false}, + {"abcd12345efg", false}, + {"abcd123efghij", true}, + {"abcd123efghijk", false}, + } + for _, testcase := range testcases { + ok, err := ValidateDictionaryPassword(testcase.pwd, &vars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, testcase.result, ok, testcase.pwd) + } +} + +func TestValidateUserNameInPassword(t *testing.T) { + sessionVars := variable.NewSessionVars(nil) + sessionVars.User = &auth.UserIdentity{Username: "user", AuthUsername: "authuser"} + sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() + testcases := []struct { + pwd string + warn string + }{ + {"", ""}, + {"user", "Password Contains User Name"}, + {"authuser", "Password Contains User Name"}, + {"resu000", "Password Contains Reversed User Name"}, + {"resuhtua", "Password Contains Reversed User Name"}, + {"User", ""}, + {"authUser", ""}, + {"Resu", ""}, + {"Resuhtua", ""}, + } + // Enable check_user_name + err := sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordCheckUserName, "ON") + require.NoError(t, err) + for _, testcase := range testcases { + warn, err := ValidateUserNameInPassword(testcase.pwd, sessionVars) + require.NoError(t, err) + require.Equal(t, testcase.warn, warn, testcase.pwd) + } + + // Disable check_user_name + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordCheckUserName, "OFF") + require.NoError(t, err) + for _, testcase := range testcases { + warn, err := ValidateUserNameInPassword(testcase.pwd, sessionVars) + require.NoError(t, err) + require.Equal(t, "", warn, testcase.pwd) + } +} + +func TestValidatePasswordLowPolicy(t *testing.T) { + sessionVars := variable.NewSessionVars(nil) + sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() + sessionVars.GlobalVarsAccessor.(*variable.MockGlobalAccessor).SessionVars = sessionVars + err := sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordLength, "8") + require.NoError(t, err) + + warn, err := ValidatePasswordLowPolicy("1234", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Length: 8", warn) + warn, err = ValidatePasswordLowPolicy("12345678", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "", warn) + + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordLength, "12") + require.NoError(t, err) + warn, err = ValidatePasswordLowPolicy("12345678", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Length: 12", warn) +} + +func TestValidatePasswordMediumPolicy(t *testing.T) { + sessionVars := variable.NewSessionVars(nil) + sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() + sessionVars.GlobalVarsAccessor.(*variable.MockGlobalAccessor).SessionVars = sessionVars + + err := sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordMixedCaseCount, "1") + require.NoError(t, err) + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordSpecialCharCount, "2") + require.NoError(t, err) + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordNumberCount, "3") + require.NoError(t, err) + + warn, err := ValidatePasswordMediumPolicy("!@A123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Lowercase Count: 1", warn) + warn, err = ValidatePasswordMediumPolicy("!@a123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Uppercase Count: 1", warn) + warn, err = ValidatePasswordMediumPolicy("!@Aa12", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Digit Count: 3", warn) + warn, err = ValidatePasswordMediumPolicy("!Aa123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Non-alphanumeric Count: 2", warn) + warn, err = ValidatePasswordMediumPolicy("!@Aa123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "", warn) +}