diff --git a/cmd/explaintest/r/new_character_set_builtin.result b/cmd/explaintest/r/new_character_set_builtin.result index a13ddf6722584..bb4e9feaae201 100644 --- a/cmd/explaintest/r/new_character_set_builtin.result +++ b/cmd/explaintest/r/new_character_set_builtin.result @@ -141,3 +141,14 @@ select md5(a), md5(b), md5(c) from t; md5(a) md5(b) md5(c) 8093a32450075324682d01456d6e3919 a45d4af7b243e7f393fa09bed72ac73e aae0117857fe54811a5239275dd81133 set @@tidb_enable_vectorized_expression = false; +drop table if exists t; +create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20)); +insert into t values ('一二三', '一二三', '一二三'); +select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"monty")) = md5(b), decode(encode(c,"monty"),"monty") = c from t; +decode(encode(a,"monty"),"monty") = a md5(decode(encode(b,"monty"),"monty")) = md5(b) decode(encode(c,"monty"),"monty") = c +1 1 1 +set @@tidb_enable_vectorized_expression = true; +select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"monty")) = md5(b), decode(encode(c,"monty"),"monty") = c from t; +decode(encode(a,"monty"),"monty") = a md5(decode(encode(b,"monty"),"monty")) = md5(b) decode(encode(c,"monty"),"monty") = c +1 1 1 +set @@tidb_enable_vectorized_expression = false; diff --git a/cmd/explaintest/t/new_character_set_builtin.test b/cmd/explaintest/t/new_character_set_builtin.test index 04ee066344d11..d5d0bcc9a14f5 100644 --- a/cmd/explaintest/t/new_character_set_builtin.test +++ b/cmd/explaintest/t/new_character_set_builtin.test @@ -66,3 +66,12 @@ select md5(a), md5(b), md5(c) from t; set @@tidb_enable_vectorized_expression = true; select md5(a), md5(b), md5(c) from t; set @@tidb_enable_vectorized_expression = false; + +-- test for builtin function decode()/encode() +drop table if exists t; +create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20)); +insert into t values ('一二三', '一二三', '一二三'); +select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"monty")) = md5(b), decode(encode(c,"monty"),"monty") = c from t; +set @@tidb_enable_vectorized_expression = true; +select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"monty")) = md5(b), decode(encode(c,"monty"),"monty") = c from t; +set @@tidb_enable_vectorized_expression = false; diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 3267033421071..47a1d717b5f37 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -407,11 +407,21 @@ func (b *builtinDecodeSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, err } + dataTp := b.args[0].GetType() + dataStr, err = charset.NewEncoding(dataTp.Charset).EncodeString(dataStr) + if err != nil { + return "", false, err + } passwordStr, isNull, err := b.args[1].EvalString(b.ctx, row) if isNull || err != nil { return "", true, err } + passwordTp := b.args[1].GetType() + passwordStr, err = charset.NewEncoding(passwordTp.Charset).EncodeString(passwordStr) + if err != nil { + return "", false, err + } decodeStr, err := encrypt.SQLDecode(dataStr, passwordStr) return decodeStr, false, err @@ -470,11 +480,21 @@ func (b *builtinEncodeSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, err } + decodeTp := b.args[0].GetType() + decodeStr, err = charset.NewEncoding(decodeTp.Charset).EncodeString(decodeStr) + if err != nil { + return "", false, err + } passwordStr, isNull, err := b.args[1].EvalString(b.ctx, row) if isNull || err != nil { return "", true, err } + passwordTp := b.args[1].GetType() + passwordStr, err = charset.NewEncoding(passwordTp.Charset).EncodeString(passwordStr) + if err != nil { + return "", false, err + } dataStr, err := encrypt.SQLEncode(decodeStr, passwordStr) return dataStr, false, err diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 505729a462d09..1224cec517fce 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -33,37 +33,44 @@ import ( ) var cryptTests = []struct { + chs string origin interface{} password interface{} crypt interface{} }{ - {"", "", ""}, - {"pingcap", "1234567890123456", "2C35B5A4ADF391"}, - {"pingcap", "asdfjasfwefjfjkj", "351CC412605905"}, - {"pingcap123", "123456789012345678901234", "7698723DC6DFE7724221"}, - {"pingcap#%$%^", "*^%YTu1234567", "8634B9C55FF55E5B6328F449"}, - {"pingcap", "", "4A77B524BD2C5C"}, - {"分布式データベース", "pass1234@#$%%^^&", "80CADC8D328B3026D04FB285F36FED04BBCA0CC685BF78B1E687CE"}, - {"分布式データベース", "分布式7782734adgwy1242", "0E24CFEF272EE32B6E0BFBDB89F29FB43B4B30DAA95C3F914444BC"}, - {"pingcap", "密匙", "CE5C02A5010010"}, - {"pingcap数据库", "数据库passwd12345667", "36D5F90D3834E30E396BE3226E3B4ED3"}, - {"数据库5667", 123.435, "B22196D0569386237AE12F8AAB"}, - {nil, "数据库passwd12345667", nil}, + {mysql.DefaultCollationName, "", "", ""}, + {mysql.DefaultCollationName, "pingcap", "1234567890123456", "2C35B5A4ADF391"}, + {mysql.DefaultCollationName, "pingcap", "asdfjasfwefjfjkj", "351CC412605905"}, + {mysql.DefaultCollationName, "pingcap123", "123456789012345678901234", "7698723DC6DFE7724221"}, + {mysql.DefaultCollationName, "pingcap#%$%^", "*^%YTu1234567", "8634B9C55FF55E5B6328F449"}, + {mysql.DefaultCollationName, "pingcap", "", "4A77B524BD2C5C"}, + {mysql.DefaultCollationName, "分布式データベース", "pass1234@#$%%^^&", "80CADC8D328B3026D04FB285F36FED04BBCA0CC685BF78B1E687CE"}, + {mysql.DefaultCollationName, "分布式データベース", "分布式7782734adgwy1242", "0E24CFEF272EE32B6E0BFBDB89F29FB43B4B30DAA95C3F914444BC"}, + {mysql.DefaultCollationName, "pingcap", "密匙", "CE5C02A5010010"}, + {"gbk", "pingcap", "密匙", "E407AC6F691ADE"}, + {mysql.DefaultCollationName, "pingcap数据库", "数据库passwd12345667", "36D5F90D3834E30E396BE3226E3B4ED3"}, + {"gbk", "pingcap数据库", "数据库passwd12345667", "B4BDBD6EC8346379F42836E2E0"}, + {mysql.DefaultCollationName, "数据库5667", 123.435, "B22196D0569386237AE12F8AAB"}, + {"gbk", "数据库5667", 123.435, "79E22979BD860EF58229"}, + {mysql.DefaultCollationName, nil, "数据库passwd12345667", nil}, } func TestSQLDecode(t *testing.T) { t.Parallel() ctx := createContext(t) - fc := funcs[ast.Decode] for _, tt := range cryptTests { - str := types.NewDatum(tt.origin) - password := types.NewDatum(tt.password) - - f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{str, password})) + err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, tt.chs) require.NoError(t, err) - crypt, err := evalBuiltinFunc(f, chunk.Row{}) + err = ctx.GetSessionVars().SetSystemVar(variable.CollationConnection, tt.chs) require.NoError(t, err) - require.Equal(t, types.NewDatum(tt.crypt), toHex(crypt)) + f, err := newFunctionForTest(ctx, ast.Decode, primitiveValsToConstants(ctx, []interface{}{tt.origin, tt.password})...) + require.NoError(t, err) + d, err := f.Eval(chunk.Row{}) + require.NoError(t, err) + if !d.IsNull() { + d = toHex(d) + } + require.Equal(t, types.NewDatum(tt.crypt), d) } testNullInput(t, ctx, ast.Decode) } @@ -71,18 +78,29 @@ func TestSQLDecode(t *testing.T) { func TestSQLEncode(t *testing.T) { t.Parallel() ctx := createContext(t) - - fc := funcs[ast.Encode] for _, test := range cryptTests { - password := types.NewDatum(test.password) - cryptStr := fromHex(test.crypt) - - f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{cryptStr, password})) + err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, test.chs) require.NoError(t, err) - str, err := evalBuiltinFunc(f, chunk.Row{}) - + err = ctx.GetSessionVars().SetSystemVar(variable.CollationConnection, test.chs) + require.NoError(t, err) + var h []byte + if test.crypt != nil { + h, _ = hex.DecodeString(test.crypt.(string)) + } else { + h = nil + } + f, err := newFunctionForTest(ctx, ast.Encode, primitiveValsToConstants(ctx, []interface{}{h, test.password})...) require.NoError(t, err) - require.Equal(t, types.NewDatum(test.origin), str) + d, err := f.Eval(chunk.Row{}) + require.NoError(t, err) + if test.origin != nil { + result, err := charset.NewEncoding(test.chs).EncodeString(test.origin.(string)) + require.NoError(t, err) + require.Equal(t, types.NewCollationStringDatum(result, test.chs, 1), d) + } else { + result := types.NewDatum(test.origin) + require.Equal(t, result.GetBytes(), d.GetBytes()) + } } testNullInput(t, ctx, ast.Encode) } diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index b7d9cfc785294..b9c0ff43ddd1e 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -212,6 +212,9 @@ func (b *builtinDecodeSig) vecEvalString(input *chunk.Chunk, result *chunk.Colum if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { return err } + dataTp := b.args[0].GetType() + dataEnc := charset.NewEncoding(dataTp.Charset) + buf1, err1 := b.bufAllocator.get() if err1 != nil { return err1 @@ -220,14 +223,22 @@ func (b *builtinDecodeSig) vecEvalString(input *chunk.Chunk, result *chunk.Colum if err := b.args[1].VecEvalString(b.ctx, input, buf1); err != nil { return err } + passwordTp := b.args[1].GetType() + passwordEnc := charset.NewEncoding(passwordTp.Charset) result.ReserveString(n) for i := 0; i < n; i++ { if buf.IsNull(i) || buf1.IsNull(i) { result.AppendNull() continue } - dataStr := buf.GetString(i) - passwordStr := buf1.GetString(i) + dataStr, err := dataEnc.EncodeString(buf.GetString(i)) + if err != nil { + return err + } + passwordStr, err := passwordEnc.EncodeString(buf1.GetString(i)) + if err != nil { + return err + } decodeStr, err := encrypt.SQLDecode(dataStr, passwordStr) if err != nil { return err @@ -255,18 +266,29 @@ func (b *builtinEncodeSig) vecEvalString(input *chunk.Chunk, result *chunk.Colum if err1 != nil { return err1 } + dataTp := b.args[0].GetType() + dataEnc := charset.NewEncoding(dataTp.Charset) defer b.bufAllocator.put(buf1) if err := b.args[1].VecEvalString(b.ctx, input, buf1); err != nil { return err } + passwordTp := b.args[1].GetType() + passwordEnc := charset.NewEncoding(passwordTp.Charset) result.ReserveString(n) for i := 0; i < n; i++ { if buf.IsNull(i) || buf1.IsNull(i) { result.AppendNull() continue } - decodeStr := buf.GetString(i) - passwordStr := buf1.GetString(i) + + decodeStr, err := dataEnc.EncodeString(buf.GetString(i)) + if err != nil { + return err + } + passwordStr, err := passwordEnc.EncodeString(buf1.GetString(i)) + if err != nil { + return err + } dataStr, err := encrypt.SQLEncode(decodeStr, passwordStr) if err != nil { return err