Skip to content

Commit

Permalink
expression: support GBK for builtin function Decode and Encode (#29315)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkingrei authored Nov 3, 2021
1 parent 731902d commit 9d9915b
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 32 deletions.
11 changes: 11 additions & 0 deletions cmd/explaintest/r/new_character_set_builtin.result
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,14 @@ select md5(a), md5(b), md5(c) from t;
md5(a) md5(b) md5(c)
8093a32450075324682d01456d6e3919 a45d4af7b243e7f393fa09bed72ac73e aae0117857fe54811a5239275dd81133
set @@tidb_enable_vectorized_expression = false;
drop table if exists t;
create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20));
insert into t values ('一二三', '一二三', '一二三');
select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"monty")) = md5(b), decode(encode(c,"monty"),"monty") = c from t;
decode(encode(a,"monty"),"monty") = a md5(decode(encode(b,"monty"),"monty")) = md5(b) decode(encode(c,"monty"),"monty") = c
1 1 1
set @@tidb_enable_vectorized_expression = true;
select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"monty")) = md5(b), decode(encode(c,"monty"),"monty") = c from t;
decode(encode(a,"monty"),"monty") = a md5(decode(encode(b,"monty"),"monty")) = md5(b) decode(encode(c,"monty"),"monty") = c
1 1 1
set @@tidb_enable_vectorized_expression = false;
9 changes: 9 additions & 0 deletions cmd/explaintest/t/new_character_set_builtin.test
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,12 @@ select md5(a), md5(b), md5(c) from t;
set @@tidb_enable_vectorized_expression = true;
select md5(a), md5(b), md5(c) from t;
set @@tidb_enable_vectorized_expression = false;

-- test for builtin function decode()/encode()
drop table if exists t;
create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20));
insert into t values ('一二三', '一二三', '一二三');
select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"monty")) = md5(b), decode(encode(c,"monty"),"monty") = c from t;
set @@tidb_enable_vectorized_expression = true;
select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"monty")) = md5(b), decode(encode(c,"monty"),"monty") = c from t;
set @@tidb_enable_vectorized_expression = false;
20 changes: 20 additions & 0 deletions expression/builtin_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,21 @@ func (b *builtinDecodeSig) evalString(row chunk.Row) (string, bool, error) {
if isNull || err != nil {
return "", true, err
}
dataTp := b.args[0].GetType()
dataStr, err = charset.NewEncoding(dataTp.Charset).EncodeString(dataStr)
if err != nil {
return "", false, err
}

passwordStr, isNull, err := b.args[1].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, err
}
passwordTp := b.args[1].GetType()
passwordStr, err = charset.NewEncoding(passwordTp.Charset).EncodeString(passwordStr)
if err != nil {
return "", false, err
}

decodeStr, err := encrypt.SQLDecode(dataStr, passwordStr)
return decodeStr, false, err
Expand Down Expand Up @@ -470,11 +480,21 @@ func (b *builtinEncodeSig) evalString(row chunk.Row) (string, bool, error) {
if isNull || err != nil {
return "", true, err
}
decodeTp := b.args[0].GetType()
decodeStr, err = charset.NewEncoding(decodeTp.Charset).EncodeString(decodeStr)
if err != nil {
return "", false, err
}

passwordStr, isNull, err := b.args[1].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, err
}
passwordTp := b.args[1].GetType()
passwordStr, err = charset.NewEncoding(passwordTp.Charset).EncodeString(passwordStr)
if err != nil {
return "", false, err
}

dataStr, err := encrypt.SQLEncode(decodeStr, passwordStr)
return dataStr, false, err
Expand Down
74 changes: 46 additions & 28 deletions expression/builtin_encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,56 +33,74 @@ import (
)

var cryptTests = []struct {
chs string
origin interface{}
password interface{}
crypt interface{}
}{
{"", "", ""},
{"pingcap", "1234567890123456", "2C35B5A4ADF391"},
{"pingcap", "asdfjasfwefjfjkj", "351CC412605905"},
{"pingcap123", "123456789012345678901234", "7698723DC6DFE7724221"},
{"pingcap#%$%^", "*^%YTu1234567", "8634B9C55FF55E5B6328F449"},
{"pingcap", "", "4A77B524BD2C5C"},
{"分布式データベース", "pass1234@#$%%^^&", "80CADC8D328B3026D04FB285F36FED04BBCA0CC685BF78B1E687CE"},
{"分布式データベース", "分布式7782734adgwy1242", "0E24CFEF272EE32B6E0BFBDB89F29FB43B4B30DAA95C3F914444BC"},
{"pingcap", "密匙", "CE5C02A5010010"},
{"pingcap数据库", "数据库passwd12345667", "36D5F90D3834E30E396BE3226E3B4ED3"},
{"数据库5667", 123.435, "B22196D0569386237AE12F8AAB"},
{nil, "数据库passwd12345667", nil},
{mysql.DefaultCollationName, "", "", ""},
{mysql.DefaultCollationName, "pingcap", "1234567890123456", "2C35B5A4ADF391"},
{mysql.DefaultCollationName, "pingcap", "asdfjasfwefjfjkj", "351CC412605905"},
{mysql.DefaultCollationName, "pingcap123", "123456789012345678901234", "7698723DC6DFE7724221"},
{mysql.DefaultCollationName, "pingcap#%$%^", "*^%YTu1234567", "8634B9C55FF55E5B6328F449"},
{mysql.DefaultCollationName, "pingcap", "", "4A77B524BD2C5C"},
{mysql.DefaultCollationName, "分布式データベース", "pass1234@#$%%^^&", "80CADC8D328B3026D04FB285F36FED04BBCA0CC685BF78B1E687CE"},
{mysql.DefaultCollationName, "分布式データベース", "分布式7782734adgwy1242", "0E24CFEF272EE32B6E0BFBDB89F29FB43B4B30DAA95C3F914444BC"},
{mysql.DefaultCollationName, "pingcap", "密匙", "CE5C02A5010010"},
{"gbk", "pingcap", "密匙", "E407AC6F691ADE"},
{mysql.DefaultCollationName, "pingcap数据库", "数据库passwd12345667", "36D5F90D3834E30E396BE3226E3B4ED3"},
{"gbk", "pingcap数据库", "数据库passwd12345667", "B4BDBD6EC8346379F42836E2E0"},
{mysql.DefaultCollationName, "数据库5667", 123.435, "B22196D0569386237AE12F8AAB"},
{"gbk", "数据库5667", 123.435, "79E22979BD860EF58229"},
{mysql.DefaultCollationName, nil, "数据库passwd12345667", nil},
}

func TestSQLDecode(t *testing.T) {
t.Parallel()
ctx := createContext(t)
fc := funcs[ast.Decode]
for _, tt := range cryptTests {
str := types.NewDatum(tt.origin)
password := types.NewDatum(tt.password)

f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{str, password}))
err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, tt.chs)
require.NoError(t, err)
crypt, err := evalBuiltinFunc(f, chunk.Row{})
err = ctx.GetSessionVars().SetSystemVar(variable.CollationConnection, tt.chs)
require.NoError(t, err)
require.Equal(t, types.NewDatum(tt.crypt), toHex(crypt))
f, err := newFunctionForTest(ctx, ast.Decode, primitiveValsToConstants(ctx, []interface{}{tt.origin, tt.password})...)
require.NoError(t, err)
d, err := f.Eval(chunk.Row{})
require.NoError(t, err)
if !d.IsNull() {
d = toHex(d)
}
require.Equal(t, types.NewDatum(tt.crypt), d)
}
testNullInput(t, ctx, ast.Decode)
}

func TestSQLEncode(t *testing.T) {
t.Parallel()
ctx := createContext(t)

fc := funcs[ast.Encode]
for _, test := range cryptTests {
password := types.NewDatum(test.password)
cryptStr := fromHex(test.crypt)

f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{cryptStr, password}))
err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, test.chs)
require.NoError(t, err)
str, err := evalBuiltinFunc(f, chunk.Row{})

err = ctx.GetSessionVars().SetSystemVar(variable.CollationConnection, test.chs)
require.NoError(t, err)
var h []byte
if test.crypt != nil {
h, _ = hex.DecodeString(test.crypt.(string))
} else {
h = nil
}
f, err := newFunctionForTest(ctx, ast.Encode, primitiveValsToConstants(ctx, []interface{}{h, test.password})...)
require.NoError(t, err)
require.Equal(t, types.NewDatum(test.origin), str)
d, err := f.Eval(chunk.Row{})
require.NoError(t, err)
if test.origin != nil {
result, err := charset.NewEncoding(test.chs).EncodeString(test.origin.(string))
require.NoError(t, err)
require.Equal(t, types.NewCollationStringDatum(result, test.chs, 1), d)
} else {
result := types.NewDatum(test.origin)
require.Equal(t, result.GetBytes(), d.GetBytes())
}
}
testNullInput(t, ctx, ast.Encode)
}
Expand Down
30 changes: 26 additions & 4 deletions expression/builtin_encryption_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ func (b *builtinDecodeSig) vecEvalString(input *chunk.Chunk, result *chunk.Colum
if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
return err
}
dataTp := b.args[0].GetType()
dataEnc := charset.NewEncoding(dataTp.Charset)

buf1, err1 := b.bufAllocator.get()
if err1 != nil {
return err1
Expand All @@ -220,14 +223,22 @@ func (b *builtinDecodeSig) vecEvalString(input *chunk.Chunk, result *chunk.Colum
if err := b.args[1].VecEvalString(b.ctx, input, buf1); err != nil {
return err
}
passwordTp := b.args[1].GetType()
passwordEnc := charset.NewEncoding(passwordTp.Charset)
result.ReserveString(n)
for i := 0; i < n; i++ {
if buf.IsNull(i) || buf1.IsNull(i) {
result.AppendNull()
continue
}
dataStr := buf.GetString(i)
passwordStr := buf1.GetString(i)
dataStr, err := dataEnc.EncodeString(buf.GetString(i))
if err != nil {
return err
}
passwordStr, err := passwordEnc.EncodeString(buf1.GetString(i))
if err != nil {
return err
}
decodeStr, err := encrypt.SQLDecode(dataStr, passwordStr)
if err != nil {
return err
Expand Down Expand Up @@ -255,18 +266,29 @@ func (b *builtinEncodeSig) vecEvalString(input *chunk.Chunk, result *chunk.Colum
if err1 != nil {
return err1
}
dataTp := b.args[0].GetType()
dataEnc := charset.NewEncoding(dataTp.Charset)
defer b.bufAllocator.put(buf1)
if err := b.args[1].VecEvalString(b.ctx, input, buf1); err != nil {
return err
}
passwordTp := b.args[1].GetType()
passwordEnc := charset.NewEncoding(passwordTp.Charset)
result.ReserveString(n)
for i := 0; i < n; i++ {
if buf.IsNull(i) || buf1.IsNull(i) {
result.AppendNull()
continue
}
decodeStr := buf.GetString(i)
passwordStr := buf1.GetString(i)

decodeStr, err := dataEnc.EncodeString(buf.GetString(i))
if err != nil {
return err
}
passwordStr, err := passwordEnc.EncodeString(buf1.GetString(i))
if err != nil {
return err
}
dataStr, err := encrypt.SQLEncode(decodeStr, passwordStr)
if err != nil {
return err
Expand Down

0 comments on commit 9d9915b

Please sign in to comment.