diff --git a/pkg/expression/bench_test.go b/pkg/expression/bench_test.go index 87bfe9d7905df..f9cbfcdc43522 100644 --- a/pkg/expression/bench_test.go +++ b/pkg/expression/bench_test.go @@ -1449,7 +1449,7 @@ func genVecBuiltinFuncBenchCase(ctx BuildContext, funcName string, testCase vecE case types.ETJson: fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} case types.ETString: - fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, false} } baseFunc, err = fc.getFunction(ctx, cols) } else if funcName == ast.GetVar { diff --git a/pkg/expression/builtin.go b/pkg/expression/builtin.go index 5db08b06cb151..b18a037d4d9b4 100644 --- a/pkg/expression/builtin.go +++ b/pkg/expression/builtin.go @@ -448,6 +448,35 @@ func newBaseBuiltinCastFunc(builtinFunc baseBuiltinFunc, inUnion bool) baseBuilt } } +func newBaseBuiltinCastFunc4String(ctx BuildContext, funcName string, args []Expression, tp *types.FieldType, isExplicitCharset bool) (baseBuiltinFunc, error) { + var bf baseBuiltinFunc + var err error + if isExplicitCharset { + bf = baseBuiltinFunc{ + bufAllocator: newLocalColumnPool(), + childrenVectorizedOnce: new(sync.Once), + + args: args, + tp: tp, + } + bf.SetCharsetAndCollation(tp.GetCharset(), tp.GetCollate()) + bf.setCollator(collate.GetCollator(tp.GetCollate())) + bf.SetCoercibility(CoercibilityExplicit) + bf.SetExplicitCharset(true) + if tp.GetCharset() == charset.CharsetASCII { + bf.SetRepertoire(ASCII) + } else { + bf.SetRepertoire(UNICODE) + } + } else { + bf, err = newBaseBuiltinFunc(ctx, funcName, args, tp) + if err != nil { + return baseBuiltinFunc{}, err + } + } + return bf, nil +} + // vecBuiltinFunc contains all vectorized methods for a builtin function. type vecBuiltinFunc interface { // vectorized returns if this builtin function itself supports vectorized evaluation. diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index 33ee29323f73a..3eeba359aa43e 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -270,14 +270,15 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx BuildContext, args []Expres type castAsStringFunctionClass struct { baseFunctionClass - tp *types.FieldType + tp *types.FieldType + isExplicitCharset bool } func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { if err := c.verifyArgs(args); err != nil { return nil, err } - bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp) + bf, err := newBaseBuiltinCastFunc4String(ctx, c.funcName, args, c.tp, c.isExplicitCharset) if err != nil { return nil, err } @@ -2057,7 +2058,9 @@ func BuildCastFunction4Union(ctx BuildContext, expr Expression, tp *types.FieldT defer func() { ctx.SetValue(inUnionCastContext, nil) }() - return BuildCastFunction(ctx, expr, tp) + res, err := BuildCastFunctionWithCheck(ctx, expr, tp, false) + terror.Log(err) + return } // BuildCastCollationFunction builds a ScalarFunction which casts the collation. @@ -2092,13 +2095,13 @@ func BuildCastCollationFunction(ctx BuildContext, expr Expression, ec *ExprColla // BuildCastFunction builds a CAST ScalarFunction from the Expression. func BuildCastFunction(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression) { - res, err := BuildCastFunctionWithCheck(ctx, expr, tp) + res, err := BuildCastFunctionWithCheck(ctx, expr, tp, false) terror.Log(err) return } // BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any. -func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression, err error) { +func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.FieldType, isExplicitCharset bool) (res Expression, err error) { argType := expr.GetType() // If source argument's nullable, then target type should be nullable if !mysql.HasNotNullFlag(argType.GetFlag()) { @@ -2124,7 +2127,7 @@ func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.Fie fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} } case types.ETString: - fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, isExplicitCharset} if expr.GetType().GetType() == mysql.TypeBit { tp.SetFlen((expr.GetType().GetFlen() + 7) / 8) } diff --git a/pkg/expression/builtin_cast_test.go b/pkg/expression/builtin_cast_test.go index d2c54affc2d1b..eeeeb32aca27c 100644 --- a/pkg/expression/builtin_cast_test.go +++ b/pkg/expression/builtin_cast_test.go @@ -655,7 +655,7 @@ func TestCastFuncSig(t *testing.T) { tp := types.NewFieldType(mysql.TypeVarString) tp.SetCharset(charset.CharsetBin) args := []Expression{c.before} - stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp) + stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false) require.NoError(t, err) switch i { case 0: @@ -742,7 +742,7 @@ func TestCastFuncSig(t *testing.T) { tp := types.NewFieldType(mysql.TypeVarString) tp.SetFlen(c.flen) tp.SetCharset(charset.CharsetBin) - stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp) + stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false) require.NoError(t, err) switch i { case 0: @@ -1099,7 +1099,7 @@ func TestCastFuncSig(t *testing.T) { // null case args := []Expression{&Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}} row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(nil)}) - bf, err := newBaseBuiltinFunc(ctx, "", args, types.NewFieldType(mysql.TypeVarString)) + bf, err := newBaseBuiltinCastFunc4String(ctx, "", args, types.NewFieldType(mysql.TypeVarString), false) require.NoError(t, err) sig = &builtinCastRealAsStringSig{bf} sRes, err := evalBuiltinFunc(sig, ctx, row.ToRow()) @@ -1694,7 +1694,7 @@ func TestCastArrayFunc(t *testing.T) { }, } for _, tt := range tbl { - f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp) + f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp, false) if !tt.buildFuncSuccess { require.Error(t, err, tt.input) continue diff --git a/pkg/expression/collation.go b/pkg/expression/collation.go index bd96f34bd41d6..a4d5b905d4bfd 100644 --- a/pkg/expression/collation.go +++ b/pkg/expression/collation.go @@ -43,6 +43,8 @@ type collationInfo struct { charset string collation string + + isExplicitCharset bool } func (c *collationInfo) HasCoercibility() bool { @@ -75,6 +77,14 @@ func (c *collationInfo) CharsetAndCollation() (string, string) { return c.charset, c.collation } +func (c *collationInfo) IsExplicitCharset() bool { + return c.isExplicitCharset +} + +func (c *collationInfo) SetExplicitCharset(explicit bool) { + c.isExplicitCharset = explicit +} + // CollationInfo contains all interfaces about dealing with collation. type CollationInfo interface { // HasCoercibility returns if the Coercibility value is initialized. @@ -97,6 +107,12 @@ type CollationInfo interface { // SetCharsetAndCollation sets charset and collation. SetCharsetAndCollation(chs, coll string) + + // IsExplicitCharset return the charset is explicit set or not. + IsExplicitCharset() bool + + // SetExplicitCharset set the charset is explicit or not. + SetExplicitCharset(bool) } // Coercibility values are used to check whether the collation of one item can be coerced to @@ -245,9 +261,8 @@ func deriveCollation(ctx BuildContext, funcName string, args []Expression, retTy case ast.Cast: // We assume all the cast are implicit. ec = &ExprCollation{args[0].Coercibility(), args[0].Repertoire(), args[0].GetType().GetCharset(), args[0].GetType().GetCollate()} - // Non-string type cast to string type should use @@character_set_connection and @@collation_connection. - // String type cast to string type should keep its original charset and collation. It should not happen. - if retType == types.ETString && argTps[0] != types.ETString { + // Cast to string type should use @@character_set_connection and @@collation_connection. + if retType == types.ETString { ec.Charset, ec.Collation = ctx.GetCharsetInfo() } return ec, nil diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index 438e8b359a1b0..bb8520f4f8287 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -824,6 +824,16 @@ func (sf *ScalarFunction) SetRepertoire(r Repertoire) { sf.Function.SetRepertoire(r) } +// IsExplicitCharset return the charset is explicit set or not. +func (sf *ScalarFunction) IsExplicitCharset() bool { + return sf.Function.IsExplicitCharset() +} + +// SetExplicitCharset set the charset is explicit or not. +func (sf *ScalarFunction) SetExplicitCharset(explicit bool) { + sf.Function.SetExplicitCharset(explicit) +} + const emptyScalarFunctionSize = int64(unsafe.Sizeof(ScalarFunction{})) // MemoryUsage return the memory usage of ScalarFunction diff --git a/pkg/expression/util.go b/pkg/expression/util.go index c3f92d69feabe..f1eb5b14e7c5c 100644 --- a/pkg/expression/util.go +++ b/pkg/expression/util.go @@ -455,8 +455,10 @@ func ColumnSubstituteImpl(ctx BuildContext, expr Expression, schema *Schema, new if substituted { flag := v.RetType.GetFlag() var e Expression + var err error if v.FuncName.L == ast.Cast { - e = BuildCastFunction(ctx, newArg, v.RetType) + e, err = BuildCastFunctionWithCheck(ctx, newArg, v.RetType, v.Function.IsExplicitCharset()) + terror.Log(err) } else { // for grouping function recreation, use clone (meta included) instead of newFunction e = v.Clone() diff --git a/pkg/expression/util_test.go b/pkg/expression/util_test.go index 25186011246d6..efd6a96cc1da4 100644 --- a/pkg/expression/util_test.go +++ b/pkg/expression/util_test.go @@ -593,6 +593,8 @@ func (m *MockExpr) Coercibility() Coercibility { return func (m *MockExpr) SetCoercibility(Coercibility) {} func (m *MockExpr) Repertoire() Repertoire { return UNICODE } func (m *MockExpr) SetRepertoire(Repertoire) {} +func (m *MockExpr) IsExplicitCharset() bool { return false } +func (m *MockExpr) SetExplicitCharset(bool) {} func (m *MockExpr) CharsetAndCollation() (string, string) { return "", "" diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index 8a278e53fa508..d1575324d7173 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -1485,7 +1485,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok return retNode, false } - castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp) + castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp, v.ExplicitCharSet) if err != nil { er.err = err return retNode, false diff --git a/tests/integrationtest/r/executor/executor.result b/tests/integrationtest/r/executor/executor.result index cbb02d5e20027..e2ed62b485521 100644 --- a/tests/integrationtest/r/executor/executor.result +++ b/tests/integrationtest/r/executor/executor.result @@ -4578,4 +4578,5 @@ LOCK TABLE executor__executor.t WRITE, test2.t2 WRITE; LOCK TABLE executor__executor.t WRITE, test2.t2 WRITE; Error 8020 (HY000): Table 't' was locked in WRITE by server: session: unlock tables; +unlock tables; drop user 'testuser'@'localhost'; diff --git a/tests/integrationtest/r/expression/cast.result b/tests/integrationtest/r/expression/cast.result index 76c9977dc52fb..a919c1e53ccf2 100644 --- a/tests/integrationtest/r/expression/cast.result +++ b/tests/integrationtest/r/expression/cast.result @@ -148,3 +148,33 @@ select 1.194192591e9 > t0.c0 from t0; select 1.194192591e9 < t0.c0 from t0; 1.194192591e9 < t0.c0 0 +drop table if exists test; +CREATE TABLE `test` ( +`id` bigint(20) NOT NULL, +`update_user` varchar(32) DEFAULT NULL, +PRIMARY KEY (`id`) /*T![clustered_index] CLUSTERED */ +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; +insert into test values(1,'张三'); +insert into test values(2,'李四'); +insert into test values(3,'张三'); +insert into test values(4,'李四'); +select * from test order by cast(update_user as char character set gbk) desc , id limit 3; +id update_user +1 张三 +3 张三 +2 李四 +drop table test; +CREATE TABLE `test` ( +`id` bigint NOT NULL, +`update_user` varchar(32) CHARACTER SET gbk COLLATE gbk_chinese_ci DEFAULT NULL, +PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; +insert into test values(1,'张三'); +insert into test values(2,'李四'); +insert into test values(3,'张三'); +insert into test values(4,'李四'); +select * from test order by cast(update_user as char) desc , id limit 3; +id update_user +2 李四 +4 李四 +1 张三 diff --git a/tests/integrationtest/t/executor/executor.test b/tests/integrationtest/t/executor/executor.test index e177f2cd6df01..8c910e62e517f 100644 --- a/tests/integrationtest/t/executor/executor.test +++ b/tests/integrationtest/t/executor/executor.test @@ -2909,6 +2909,9 @@ connection default; --error 8020 LOCK TABLE executor__executor.t WRITE, test2.t2 WRITE; +connection conn1; +unlock tables; + disconnect conn1; unlock tables; drop user 'testuser'@'localhost'; diff --git a/tests/integrationtest/t/expression/cast.test b/tests/integrationtest/t/expression/cast.test index 03843628ab9e6..7af48437383a9 100644 --- a/tests/integrationtest/t/expression/cast.test +++ b/tests/integrationtest/t/expression/cast.test @@ -87,3 +87,29 @@ select t0.c0 > 1.194192591e9 from t0; select t0.c0 < 1.194192591e9 from t0; select 1.194192591e9 > t0.c0 from t0; select 1.194192591e9 < t0.c0 from t0; + +# TestCastAsStringExplicitCharSet +drop table if exists test; +CREATE TABLE `test` ( + `id` bigint(20) NOT NULL, + `update_user` varchar(32) DEFAULT NULL, + PRIMARY KEY (`id`) /*T![clustered_index] CLUSTERED */ +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; +insert into test values(1,'张三'); +insert into test values(2,'李四'); +insert into test values(3,'张三'); +insert into test values(4,'李四'); +select * from test order by cast(update_user as char character set gbk) desc , id limit 3; + +drop table test; +CREATE TABLE `test` ( + `id` bigint NOT NULL, + `update_user` varchar(32) CHARACTER SET gbk COLLATE gbk_chinese_ci DEFAULT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; +insert into test values(1,'张三'); +insert into test values(2,'李四'); +insert into test values(3,'张三'); +insert into test values(4,'李四'); +select * from test order by cast(update_user as char) desc , id limit 3; +