diff --git a/expression/builtin.go b/expression/builtin.go index 2ec8672c5cce1..a7e9537e84a52 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -140,6 +140,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex args[i] = WrapWithCastAsDecimal(ctx, args[i]) case types.ETString: args[i] = WrapWithCastAsString(ctx, args[i]) + args[i] = WrapWithToBinary(ctx, args[i], funcName) case types.ETDatetime: args[i] = WrapWithCastAsTime(ctx, args[i], types.NewFieldType(mysql.TypeDatetime)) case types.ETTimestamp: @@ -879,6 +880,9 @@ var funcs = map[string]functionClass{ ast.NextVal: &nextValFunctionClass{baseFunctionClass{ast.NextVal, 1, 1}}, ast.LastVal: &lastValFunctionClass{baseFunctionClass{ast.LastVal, 1, 1}}, ast.SetVal: &setValFunctionClass{baseFunctionClass{ast.SetVal, 2, 2}}, + + // TiDB implicit internal functions. + InternalFuncToBinary: &tidbConvertCharsetFunctionClass{baseFunctionClass{InternalFuncToBinary, 1, 1}}, } // IsFunctionSupported check if given function name is a builtin sql function. @@ -902,6 +906,7 @@ func GetDisplayName(name string) string { func GetBuiltinList() []string { res := make([]string, 0, len(funcs)) notImplementedFunctions := []string{ast.RowFunc, ast.IsTruthWithNull} + implicitFunctions := []string{InternalFuncToBinary} for funcName := range funcs { skipFunc := false // Skip not implemented functions @@ -910,6 +915,11 @@ func GetBuiltinList() []string { skipFunc = true } } + for _, implicitFunc := range implicitFunctions { + if funcName == implicitFunc { + skipFunc = true + } + } // Skip literal functions // (their names are not readable: 'tidb`.(dateliteral, for example) // See: https://github.com/pingcap/parser/pull/591 diff --git a/expression/builtin_convert_charset.go b/expression/builtin_convert_charset.go new file mode 100644 index 0000000000000..243a50e5d0a4e --- /dev/null +++ b/expression/builtin_convert_charset.go @@ -0,0 +1,136 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "fmt" + + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/charset" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tipb/go-tipb" +) + +// InternalFuncToBinary accepts a string and returns another string encoded in a given charset. +const InternalFuncToBinary = "to_binary" + +type tidbConvertCharsetFunctionClass struct { + baseFunctionClass +} + +func (c *tidbConvertCharsetFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, c.verifyArgs(args) + } + argTp := args[0].GetType().EvalType() + var sig builtinFunc + switch argTp { + case types.ETString: + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString) + if err != nil { + return nil, err + } + sig = &builtinInternalToBinarySig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ToBinary) + default: + return nil, fmt.Errorf("unexpected argTp: %d", argTp) + } + return sig, nil +} + +var _ builtinFunc = &builtinInternalToBinarySig{} + +type builtinInternalToBinarySig struct { + baseBuiltinFunc +} + +func (b *builtinInternalToBinarySig) Clone() builtinFunc { + newSig := &builtinInternalToBinarySig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinInternalToBinarySig) evalString(row chunk.Row) (res string, isNull bool, err error) { + val, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return res, isNull, err + } + tp := b.args[0].GetType() + enc := charset.NewEncoding(tp.Charset) + res, err = enc.EncodeString(val) + return res, false, err +} + +func (b *builtinInternalToBinarySig) vectorized() bool { + return true +} + +func (b *builtinInternalToBinarySig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { + return err + } + enc := charset.NewEncoding(b.args[0].GetType().Charset) + result.ReserveString(n) + for i := 0; i < n; i++ { + var str string + if buf.IsNull(i) { + result.AppendNull() + continue + } + str = buf.GetString(i) + str, err = enc.EncodeString(str) + if err != nil { + return err + } + result.AppendString(str) + } + return nil +} + +// toBinaryMap contains the builtin functions which arguments need to be converted to the correct charset. +var toBinaryMap = map[string]struct{}{ + ast.Hex: {}, ast.Length: {}, ast.OctetLength: {}, ast.ASCII: {}, + ast.ToBase64: {}, +} + +// WrapWithToBinary wraps `expr` with to_binary sig. +func WrapWithToBinary(ctx sessionctx.Context, expr Expression, funcName string) Expression { + exprTp := expr.GetType() + if _, err := charset.GetDefaultCollationLegacy(exprTp.Charset); err != nil { + if _, ok := toBinaryMap[funcName]; ok { + fc := funcs[InternalFuncToBinary] + sig, err := fc.getFunction(ctx, []Expression{expr}) + if err != nil { + return expr + } + sf := &ScalarFunction{ + FuncName: model.NewCIStr(InternalFuncToBinary), + RetType: exprTp, + Function: sig, + } + return FoldConstant(sf) + } + } + return expr +} diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 13733abd4cc10..bf948605045b2 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -222,15 +222,6 @@ func (b *builtinLengthSig) evalInt(row chunk.Row) (int64, bool, error) { if isNull || err != nil { return 0, isNull, err } - - argTp := b.args[0].GetType() - if !types.IsBinaryStr(argTp) { - dBytes, err := charset.NewEncoding(argTp.Charset).EncodeString(val) - if err == nil { - return int64(len(dBytes)), false, nil - } - } - return int64(len([]byte(val))), false, nil } @@ -272,13 +263,6 @@ func (b *builtinASCIISig) evalInt(row chunk.Row) (int64, bool, error) { if len(val) == 0 { return 0, false, nil } - argTp := b.args[0].GetType() - if !types.IsBinaryStr(argTp) { - dBytes, err := charset.NewEncoding(argTp.Charset).EncodeString(val) - if err == nil { - return int64(dBytes[0]), false, nil - } - } return int64(val[0]), false, nil } @@ -1664,7 +1648,7 @@ func (c *hexFunctionClass) getFunction(ctx sessionctx.Context, args []Expression argFieldTp := args[0].GetType() // Use UTF8MB4 as default. bf.tp.Flen = argFieldTp.Flen * 4 * 2 - sig := &builtinHexStrArgSig{bf, charset.NewEncoding(argFieldTp.Charset)} + sig := &builtinHexStrArgSig{bf} sig.setPbCode(tipb.ScalarFuncSig_HexStrArg) return sig, nil case types.ETInt, types.ETReal, types.ETDecimal: @@ -1684,15 +1668,11 @@ func (c *hexFunctionClass) getFunction(ctx sessionctx.Context, args []Expression type builtinHexStrArgSig struct { baseBuiltinFunc - encoding *charset.Encoding } func (b *builtinHexStrArgSig) Clone() builtinFunc { newSig := &builtinHexStrArgSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - if b.encoding != nil { - newSig.encoding = charset.NewEncoding(b.encoding.Name()) - } return newSig } @@ -1703,12 +1683,7 @@ func (b *builtinHexStrArgSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return d, isNull, err } - dBytes := hack.Slice(d) - dBytes, err = b.encoding.Encode(nil, dBytes) - if err != nil { - return d, false, err - } - return strings.ToUpper(hex.EncodeToString(dBytes)), false, nil + return strings.ToUpper(hex.EncodeToString(hack.Slice(d))), false, nil } type builtinHexIntArgSig struct { @@ -3634,11 +3609,6 @@ func (b *builtinToBase64Sig) evalString(row chunk.Row) (d string, isNull bool, e if isNull || err != nil { return "", isNull, err } - argTp := b.args[0].GetType() - str, err = charset.NewEncoding(argTp.Charset).EncodeString(str) - if err != nil { - return "", false, err - } needEncodeLen := base64NeededEncodedLength(len(str)) if needEncodeLen == -1 { return "", true, nil diff --git a/expression/builtin_string_vec.go b/expression/builtin_string_vec.go index ca21e724f0dfa..11933f305fe58 100644 --- a/expression/builtin_string_vec.go +++ b/expression/builtin_string_vec.go @@ -447,7 +447,6 @@ func (b *builtinHexStrArgSig) vecEvalString(input *chunk.Chunk, result *chunk.Co return err } defer b.bufAllocator.put(buf0) - var encodedBuf []byte if err := b.args[0].VecEvalString(b.ctx, input, buf0); err != nil { return err } @@ -457,13 +456,7 @@ func (b *builtinHexStrArgSig) vecEvalString(input *chunk.Chunk, result *chunk.Co result.AppendNull() continue } - buf0Bytes := buf0.GetBytes(i) - encodedBuf, err = b.encoding.Encode(encodedBuf, buf0Bytes) - if err != nil { - return err - } - buf0Bytes = encodedBuf - result.AppendString(strings.ToUpper(hex.EncodeToString(buf0Bytes))) + result.AppendString(strings.ToUpper(hex.EncodeToString(buf0.GetBytes(i)))) } return nil } @@ -912,11 +905,6 @@ func (b *builtinASCIISig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e if err = b.args[0].VecEvalString(b.ctx, input, buf); err != nil { return err } - - argTp := b.args[0].GetType() - enc := charset.NewEncoding(argTp.Charset) - isBinaryStr := types.IsBinaryStr(argTp) - result.ResizeInt64(n, false) result.MergeNulls(buf) i64s := result.Int64s() @@ -929,14 +917,6 @@ func (b *builtinASCIISig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e i64s[i] = 0 continue } - if !isBinaryStr { - dBytes, err := enc.EncodeString(str) - if err != nil { - return err - } - i64s[i] = int64(dBytes[0]) - continue - } i64s[i] = int64(str[0]) } return nil @@ -2162,27 +2142,14 @@ func (b *builtinLengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) return err } - argTp := b.args[0].GetType() - enc := charset.NewEncoding(argTp.Charset) - isBinaryStr := types.IsBinaryStr(argTp) - result.ResizeInt64(n, false) result.MergeNulls(buf) i64s := result.Int64s() - var encodeBuf []byte for i := 0; i < n; i++ { if result.IsNull(i) { continue } str := buf.GetBytes(i) - if !isBinaryStr { - dBytes, err := enc.Encode(encodeBuf, str) - if err != nil { - return err - } - i64s[i] = int64(len(dBytes)) - continue - } i64s[i] = int64(len(str)) } return nil @@ -2470,20 +2437,13 @@ func (b *builtinToBase64Sig) vecEvalString(input *chunk.Chunk, result *chunk.Col if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { return err } - - argTp := b.args[0].GetType() - enc := charset.NewEncoding(argTp.Charset) - result.ReserveString(n) for i := 0; i < n; i++ { if buf.IsNull(i) { result.AppendNull() continue } - str, err := enc.EncodeString(buf.GetString(i)) - if err != nil { - return err - } + str := buf.GetString(i) needEncodeLen := base64NeededEncodedLength(len(str)) if needEncodeLen == -1 { result.AppendNull() diff --git a/expression/constant_test.go b/expression/constant_test.go index 08d201ffd6b38..2158b1e4f5d66 100644 --- a/expression/constant_test.go +++ b/expression/constant_test.go @@ -49,9 +49,20 @@ func newLonglong(value int64) *Constant { } } +func newString(value string, collation string) *Constant { + return &Constant{ + Value: types.NewStringDatum(value), + RetType: types.NewFieldTypeWithCollation(mysql.TypeVarchar, collation, 255), + } +} + func newFunction(funcName string, args ...Expression) Expression { - typeLong := types.NewFieldType(mysql.TypeLonglong) - return NewFunctionInternal(mock.NewContext(), funcName, typeLong, args...) + return newFunctionWithType(funcName, mysql.TypeLonglong, args...) +} + +func newFunctionWithType(funcName string, tp byte, args ...Expression) Expression { + ft := types.NewFieldType(tp) + return NewFunctionInternal(mock.NewContext(), funcName, ft, args...) } func TestConstantPropagation(t *testing.T) { @@ -220,6 +231,31 @@ func TestConstantFolding(t *testing.T) { } } +func TestConstantFoldingCharsetConvert(t *testing.T) { + t.Parallel() + tests := []struct { + condition Expression + result string + }{ + { + condition: newFunction(ast.Length, newFunctionWithType( + InternalFuncToBinary, mysql.TypeVarchar, + newString("中文", "gbk_bin"))), + result: "4", + }, + { + condition: newFunction(ast.Length, newFunctionWithType( + InternalFuncToBinary, mysql.TypeVarchar, + newString("中文", "utf8mb4_bin"))), + result: "6", + }, + } + for _, tt := range tests { + newConds := FoldConstant(tt.condition) + require.Equalf(t, tt.result, newConds.String(), "different for expr %s", tt.condition) + } +} + func TestDeferredParamNotNull(t *testing.T) { t.Parallel() diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index d4f0d5eb84fd9..db9b39a2db010 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -21,7 +21,6 @@ import ( "time" "github.com/pingcap/errors" - "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" @@ -965,11 +964,7 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti case tipb.ScalarFuncSig_HexIntArg: f = &builtinHexIntArgSig{base} case tipb.ScalarFuncSig_HexStrArg: - chs, args := "utf-8", base.getArgs() - if len(args) == 1 { - chs, _ = args[0].CharsetAndCollation() - } - f = &builtinHexStrArgSig{base, charset.NewEncoding(chs)} + f = &builtinHexStrArgSig{base} case tipb.ScalarFuncSig_InsertUTF8: f = &builtinInsertUTF8Sig{base, maxAllowedPacket} case tipb.ScalarFuncSig_Insert: diff --git a/go.mod b/go.mod index aa4dc6c700ff1..165ed4ad2f538 100644 --- a/go.mod +++ b/go.mod @@ -53,7 +53,7 @@ require ( github.com/pingcap/sysutil v0.0.0-20210730114356-fcd8a63f68c5 github.com/pingcap/tidb-tools v5.2.2-0.20211019062242-37a8bef2fa17+incompatible github.com/pingcap/tidb/parser v0.0.0-20211011031125-9b13dc409c5e - github.com/pingcap/tipb v0.0.0-20211105090418-71142a4d40e3 + github.com/pingcap/tipb v0.0.0-20211116093845-e9b045a0bdf8 github.com/prometheus/client_golang v1.5.1 github.com/prometheus/client_model v0.2.0 github.com/prometheus/common v0.9.1 diff --git a/go.sum b/go.sum index 1237e7f202a53..804792f146573 100644 --- a/go.sum +++ b/go.sum @@ -600,8 +600,8 @@ github.com/pingcap/tidb-dashboard v0.0.0-20211008050453-a25c25809529/go.mod h1:O github.com/pingcap/tidb-dashboard v0.0.0-20211031170437-08e58c069a2a/go.mod h1:OCXbZTBTIMRcIt0jFsuCakZP+goYRv6IjawKbwLS2TQ= github.com/pingcap/tidb-tools v5.2.2-0.20211019062242-37a8bef2fa17+incompatible h1:c7+izmker91NkjkZ6FgTlmD4k1A5FLOAq+li6Ki2/GY= github.com/pingcap/tidb-tools v5.2.2-0.20211019062242-37a8bef2fa17+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM= -github.com/pingcap/tipb v0.0.0-20211105090418-71142a4d40e3 h1:xnp/Qkk5gELlB8TaY6oro0JNXMBXTafNVxU/vbrNU8I= -github.com/pingcap/tipb v0.0.0-20211105090418-71142a4d40e3/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= +github.com/pingcap/tipb v0.0.0-20211116093845-e9b045a0bdf8 h1:Vu/6oq8EFNWgyXRHiclNzTKIu+YKHPCSI/Ba5oVrLtM= +github.com/pingcap/tipb v0.0.0-20211116093845-e9b045a0bdf8/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=