diff --git a/executor/errors.go b/executor/errors.go index a48152f0acdfe..89f97057c8afb 100644 --- a/executor/errors.go +++ b/executor/errors.go @@ -51,6 +51,7 @@ var ( ErrBadDB = terror.ClassExecutor.New(mysql.ErrBadDB, mysql.MySQLErrName[mysql.ErrBadDB]) ErrWrongObject = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject]) ErrRoleNotGranted = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted]) + ErrNotValidPassword = terror.ClassExecutor.New(mysql.ErrNotValidPassword, mysql.MySQLErrName[mysql.ErrNotValidPassword]) ErrDeadlock = terror.ClassExecutor.New(mysql.ErrLockDeadlock, mysql.MySQLErrName[mysql.ErrLockDeadlock]) ErrQueryInterrupted = terror.ClassExecutor.New(mysql.ErrQueryInterrupted, mysql.MySQLErrName[mysql.ErrQueryInterrupted]) ) @@ -69,6 +70,7 @@ func init() { mysql.ErrTableaccessDenied: mysql.ErrTableaccessDenied, mysql.ErrBadDB: mysql.ErrBadDB, mysql.ErrWrongObject: mysql.ErrWrongObject, + mysql.ErrNotValidPassword: mysql.ErrNotValidPassword, mysql.ErrLockDeadlock: mysql.ErrLockDeadlock, mysql.ErrQueryInterrupted: mysql.ErrQueryInterrupted, } diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 85083cc5cfeb3..de5c253b1566a 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -26,7 +26,10 @@ import ( "fmt" "hash" "io" + "strconv" "strings" + "unicode" + "unicode/utf8" "github.com/pingcap/errors" "github.com/pingcap/parser/auth" @@ -76,6 +79,15 @@ var ( // ivSize indicates the initialization vector supplied to aes_decrypt const ivSize = aes.BlockSize +// the max length of a password is 100 in mysql +const maxPwdLength int64 = 100 + +// VALIDATE_PASSWORD_STRENGTH() will return 0 when the length of a password is less than minPwsLength +const minPwdLength int64 = 4 + +// Differential score between levels in password test +const differentialScore int64 = 25 + // aesModeAttr indicates that the key length and iv attribute for specific block_encryption_mode. // keySize is the key length in bits and mode is the encryption mode. // ivRequired indicates that initialization vector is required or not. @@ -895,10 +907,160 @@ func (b *builtinUncompressedLengthSig) evalInt(row chunk.Row) (int64, bool, erro return int64(len), false, nil } +func reverse(s string) string { + runes := []rune(s) + for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 { + runes[i], runes[j] = runes[j], runes[i] + } + return string(runes) +} + +func validateByUser(s *variable.SessionVars, psw string) (bool, error) { + v, err := variable.GetGlobalSystemVar(s, variable.ValidatePasswordCheckUserName) + if err != nil || strings.EqualFold(v, "OFF") { + return err == nil, err + } + if s.User == nil { + return true, nil + } + if n := s.User.Username; n != "" { + if psw == n || psw == reverse(n) { + return false, nil + } + } + if n := s.User.AuthUsername; n != "" { + if psw == n || psw == reverse(n) { + return false, nil + } + } + return true, nil +} + +func validateByMixedDigitSpecial(s *variable.SessionVars, pwd string) (bool, error) { + // get system variables + v, err := variable.GetGlobalSystemVar(s, variable.ValidatePasswordMixedCaseCount) + if err != nil { + return false, err + } + minMixed, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return false, err + } + v, err = variable.GetGlobalSystemVar(s, variable.ValidatePasswordNumberCount) + if err != nil { + return false, err + } + minDigit, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return false, err + } + v, err = variable.GetGlobalSystemVar(s, variable.ValidatePasswordSpecialCharCount) + if err != nil { + return false, err + } + minSpecial, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return false, err + } + + numLower := int64(0) + numUpper := int64(0) + numDigit := int64(0) + numSpecial := int64(0) + for _, c := range pwd { + if unicode.IsLower(c) { + numLower++ + } else if unicode.IsUpper(c) { + numUpper++ + } else if unicode.IsDigit(c) { + numDigit++ + } else { + numSpecial++ + } + } + + if numLower < minMixed || numUpper < minMixed || numDigit < minDigit || numSpecial < minSpecial { + return false, nil + } + return true, nil +} + +// TODO: Support validating password by dictionary file +func validateByDictionary(s *variable.SessionVars, pwd string) (bool, error) { + return true, nil +} + type validatePasswordStrengthFunctionClass struct { baseFunctionClass } func (c *validatePasswordStrengthFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "VALIDATE_PASSWORD_STRENGTH") + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETString) + bf.tp.Flag = mysql.MaxIntWidth + sig := &builtinValidatePasswordStrengthSig{bf} + return sig, nil +} + +type builtinValidatePasswordStrengthSig struct { + baseBuiltinFunc +} + +func (c *builtinValidatePasswordStrengthSig) Clone() builtinFunc { + newSig := &builtinValidatePasswordStrengthSig{} + newSig.cloneFrom(&c.baseBuiltinFunc) + return newSig +} + +// evalInt evals VALIDATE_PASSWORD_STRENGTH(str). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_validate-password-strength +func (c *builtinValidatePasswordStrengthSig) evalInt(row chunk.Row) (int64, bool, error) { + sv := c.ctx.GetSessionVars() + pwd, isNull, err := c.args[0].EvalString(c.ctx, row) + score := int64(0) + + if isNull || err != nil { + return score, true, err + } + + l := int64(utf8.RuneCountInString(pwd)) + if l > maxPwdLength { + pwd = string([]rune(pwd)[0:maxPwdLength]) + } + + valid, err := validateByUser(sv, pwd) + if err != nil || !valid { + return score, err != nil, err + } + + if l < minPwdLength { + return score, false, nil + } + score += differentialScore + + v, err := variable.GetGlobalSystemVar(sv, variable.ValidatePasswordLength) + if err != nil { + return score, false, nil + } + valPwdLen, err := strconv.ParseInt(v, 10, 64) + if err != nil || l < valPwdLen { + return score, err != nil, err + } + score += differentialScore + + valid, err = validateByMixedDigitSpecial(sv, pwd) + if err != nil || !valid { + return score, err != nil, err + } + score += differentialScore + + valid, err = validateByDictionary(sv, pwd) + if err != nil || !valid { + return score, err != nil, nil + } + score += differentialScore + + return score, false, nil } diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 4a12f14f1a9af..10b710ba3a21d 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -493,3 +493,35 @@ func (s *testEvaluatorSuite) TestPassword(c *C) { _, err := funcs[ast.PasswordFunc].getFunction(s.ctx, []Expression{Zero}) c.Assert(err, IsNil) } + +func (s *testEvaluatorSuite) TestValidatePasswordStrength(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + in interface{} + expect int64 + }{ + {"", 0}, + {"1", 0}, + {"你好啊", 0}, + {"pass", 25}, + {"user123", 25}, + {"你好世界", 25}, + {"password", 50}, + {"password0000", 50}, + {"auth_user", 50}, + {"你好世界你好世界", 50}, + {"Pingcap123", 50}, + {"Pingcap123_", 100}, + {"password1A#", 100}, + {"PA12wrd!#", 100}, + {"PA00wrd!#", 100}, + } + + for _, t := range tests { + f, err := newFunctionForTest(s.ctx, ast.ValidatePasswordStrength, s.primitiveValsToConstants([]interface{}{t.in})...) + c.Assert(err, IsNil) + d, err := f.Eval(chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(d.GetInt64(), Equals, t.expect) + } +} diff --git a/expression/integration_test.go b/expression/integration_test.go index 59bdae56242d7..81ba83e4955e3 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -4368,6 +4368,65 @@ func (s *testIntegrationSuite) TestIssue9325(c *C) { result.Check(testkit.Rows("2019-02-16 14:19:59", "2019-02-16 14:20:01")) } +func (s *testIntegrationSuite) TestFuncValidatePasswordStrength(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set global validate_password_check_user_name = 'ON';") + tk.MustQuery("SELECT @@global.validate_password_check_user_name;").Check(testkit.Rows("1")) + err := tk.ExecToErr("SET @@session.validate_password_check_user_name= ON;") + c.Assert(err, ErrorMatches, "*Variable 'validate_password_check_user_name' is a GLOBAL variable and should be set with SET GLOBAL") + err = tk.ExecToErr("SET validate_password_check_user_name= ON;") + c.Assert(err, ErrorMatches, "*Variable 'validate_password_check_user_name' is a GLOBAL variable and should be set with SET GLOBAL") + tk.MustExec("SET @@global.validate_password_policy=LOW;") + tk.MustExec("SET @@global.validate_password_mixed_case_count=0;") + tk.MustExec("SET @@global.validate_password_number_count=0;") + tk.MustExec("SET @@global.validate_password_special_char_count=0;") + tk.MustExec("SET @@global.validate_password_length=0;") + tk.MustExec("SET @@global.validate_password_check_user_name= ON;") + tk.Se.GetSessionVars().User = &auth.UserIdentity{Username: "root", AuthUsername: "root"} + result := tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('root') = 0;") + result.Check(testkit.Rows("1")) + result = tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('toor') = 0;") + result.Check(testkit.Rows("1")) + result = tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('Root') <> 0;") + result.Check(testkit.Rows("1")) + result = tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('Toor') <> 0;") + result.Check(testkit.Rows("1")) + result = tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH('fooHoHo%1') <> 0;") + result.Check(testkit.Rows("1")) + + err = tk.ExecToErr("SELECT VALIDATE_PASSWORD_STRENGTH('password', 0);") + c.Check(err, ErrorMatches, "*Incorrect parameter count in the call to native function 'validate_password_strength'") + err = tk.ExecToErr("SELECT VALIDATE_PASSWORD_STRENGTH();") + c.Check(err, ErrorMatches, "*Incorrect parameter count in the call to native function 'validate_password_strength'") + + tk.MustExec("set global validate_password_length = 10;") + tk.MustExec("set global validate_password_mixed_case_count = 3;") + tk.MustExec("set global validate_password_number_count = 3;") + tk.MustExec("set global validate_password_special_char_count = 3;") + tk.Se.GetSessionVars().User = &auth.UserIdentity{Username: "user123", AuthUsername: "auth_user"} + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('user123')") + result.Check(testkit.Rows("0")) + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('321resu')") + result.Check(testkit.Rows("0")) + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('auth_user')") + result.Check(testkit.Rows("0")) + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('resu_htua')") + result.Check(testkit.Rows("0")) + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('password')") + result.Check(testkit.Rows("25")) + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('password12')") + result.Check(testkit.Rows("50")) + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('Pingcap123_')") + result.Check(testkit.Rows("50")) + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('PingcapPP123_')") + result.Check(testkit.Rows("50")) + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('PINGCAp123_')") + result.Check(testkit.Rows("50")) + result = tk.MustQuery("select VALIDATE_PASSWORD_STRENGTH('Pingcap_PP_123!')") + result.Check(testkit.Rows("100")) + +} + func (s *testIntegrationSuite) TestIssue9710(c *C) { tk := testkit.NewTestKit(c, s.store) getSAndMS := func(str string) (int, int) { diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 9d8ade09e4748..1317df443e5f5 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -135,8 +135,14 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "slave_pending_jobs_size_max", "16777216"}, {ScopeNone, "innodb_sync_array_size", "1"}, {ScopeSession, "rand_seed2", ""}, + {ScopeGlobal, "validate_password_check_user_name", "OFF"}, + {ScopeGlobal, "validate_password_dictionary_file", ""}, + {ScopeGlobal, "validate_password_length", "8"}, + {ScopeGlobal, "validate_password_mixed_case_count", "1"}, {ScopeGlobal, ValidatePasswordCheckUserName, "0"}, {ScopeGlobal, "validate_password_number_count", "1"}, + {ScopeGlobal, "validate_password_policy", "MEDIUM"}, + {ScopeGlobal, "validate_password_special_char_count", "1"}, {ScopeSession, "gtid_next", ""}, {ScopeGlobal | ScopeSession, SQLSelectLimit, "18446744073709551615"}, {ScopeGlobal, "ndb_show_foreign_key_mock_tables", ""}, @@ -606,7 +612,6 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "sync_relay_log_info", "10000"}, {ScopeGlobal | ScopeSession, "optimizer_trace_limit", "1"}, {ScopeNone, "innodb_ft_max_token_size", "84"}, - {ScopeGlobal, "validate_password_length", "8"}, {ScopeGlobal, "ndb_log_binlog_index", ""}, {ScopeGlobal, "innodb_api_bk_commit_interval", "5"}, {ScopeNone, "innodb_undo_directory", "."}, @@ -815,10 +820,20 @@ const ( BlockEncryptionMode = "block_encryption_mode" // WaitTimeout is the name for 'wait_timeout' system variable. WaitTimeout = "wait_timeout" - // ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable. - ValidatePasswordNumberCount = "validate_password_number_count" + // ValidatePasswordCheckUserName is the name of 'validate_password_check_user_name' system variable + ValidatePasswordCheckUserName = "validate_password_check_user_name" + // ValidatePasswordDictionaryFile is the name of 'validate_password_dictionary_file' system variable + ValidatePasswordDictionaryFile = "validate_password_dictionary_file" // ValidatePasswordLength is the name of 'validate_password_length' system variable. ValidatePasswordLength = "validate_password_length" + // ValidatePasswordMixedCaseCount is the name of 'validate_password_mixed_case_count' system variable + ValidatePasswordMixedCaseCount = "validate_password_mixed_case_count" + // ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable. + ValidatePasswordNumberCount = "validate_password_number_count" + // ValidatePasswordPolicy is the name of 'validate_password_policy' system variable + ValidatePasswordPolicy = "validate_password_policy" + // ValidatePasswordSpecialCharCount is the name of 'validate_password_special_char_count' system variable + ValidatePasswordSpecialCharCount = "validate_password_special_char_count" // PluginDir is the name of 'plugin_dir' system variable. PluginDir = "plugin_dir" // PluginLoad is the name of 'plugin_load' system variable. @@ -835,8 +850,6 @@ const ( BinlogOrderCommits = "binlog_order_commits" // MasterVerifyChecksum is the name for 'master_verify_checksum' system variable. MasterVerifyChecksum = "master_verify_checksum" - // ValidatePasswordCheckUserName is the name for 'validate_password_check_user_name' system variable. - ValidatePasswordCheckUserName = "validate_password_check_user_name" // SuperReadOnly is the name for 'super_read_only' system variable. SuperReadOnly = "super_read_only" // SQLNotes is the name for 'sql_notes' system variable. diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index ae244d93face2..e7d2feeac5375 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -347,8 +347,17 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, } _, err := parseTimeZone(value) return value, err - case ValidatePasswordLength, ValidatePasswordNumberCount: - return checkUInt64SystemVar(name, value, 0, math.MaxUint64, vars) + case ValidatePasswordLength, ValidatePasswordMixedCaseCount, ValidatePasswordNumberCount, ValidatePasswordSpecialCharCount: + return checkUInt64SystemVar(name, value, 0, 100, vars) + case ValidatePasswordPolicy: + if strings.EqualFold(value, "LOW") || value == "0" { + return "0", nil + } else if strings.EqualFold(value, "MEDIUM") || value == "1" { + return "1", nil + } else if strings.EqualFold(value, "STRONG") || value == "2" { + return "2", nil + } + return value, ErrWrongValueForVar.GenWithStackByArgs(name, value) case WarningCount, ErrorCount: return value, ErrReadOnly.GenWithStackByArgs(name) case EnforceGtidConsistency: