diff --git a/expression/bench_test.go b/expression/bench_test.go index 7157ef232dcf2..872ec6f05f0a4 100644 --- a/expression/bench_test.go +++ b/expression/bench_test.go @@ -1452,7 +1452,7 @@ func genVecBuiltinFuncBenchCase(ctx sessionctx.Context, funcName string, testCas case types.ETJson: fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} case types.ETString: - fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, false} } baseFunc, err = fc.getFunction(ctx, cols) } else if funcName == ast.GetVar { diff --git a/expression/builtin.go b/expression/builtin.go index f10cf9aa3dfa9..09fcad85ea23e 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -418,6 +418,36 @@ func newBaseBuiltinCastFunc(builtinFunc baseBuiltinFunc, inUnion bool) baseBuilt } } +func newBaseBuiltinCastFunc4String(ctx sessionctx.Context, funcName string, args []Expression, tp *types.FieldType, isExplicitCharset bool) (baseBuiltinFunc, error) { + var bf baseBuiltinFunc + var err error + if isExplicitCharset { + bf = baseBuiltinFunc{ + bufAllocator: newLocalColumnPool(), + childrenVectorizedOnce: new(sync.Once), + + args: args, + ctx: ctx, + tp: tp, + } + bf.SetCharsetAndCollation(tp.GetCharset(), tp.GetCollate()) + bf.setCollator(collate.GetCollator(tp.GetCollate())) + bf.SetCoercibility(CoercibilityExplicit) + bf.SetExplicitCharset(true) + if tp.GetCharset() == charset.CharsetASCII { + bf.SetRepertoire(ASCII) + } else { + bf.SetRepertoire(UNICODE) + } + } else { + bf, err = newBaseBuiltinFunc(ctx, funcName, args, tp) + if err != nil { + return baseBuiltinFunc{}, err + } + } + return bf, nil +} + // vecBuiltinFunc contains all vectorized methods for a builtin function. type vecBuiltinFunc interface { // vectorized returns if this builtin function itself supports vectorized evaluation. diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 02672ea2b6bb8..3b6fe678f663c 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -270,14 +270,15 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx sessionctx.Context, args [] type castAsStringFunctionClass struct { baseFunctionClass - tp *types.FieldType + tp *types.FieldType + isExplicitCharset bool } func (c *castAsStringFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (sig builtinFunc, err error) { if err := c.verifyArgs(args); err != nil { return nil, err } - bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp) + bf, err := newBaseBuiltinCastFunc4String(ctx, c.funcName, args, c.tp, c.isExplicitCharset) if err != nil { return nil, err } @@ -1924,6 +1925,13 @@ func BuildCastCollationFunction(ctx sessionctx.Context, expr Expression, ec *Exp // BuildCastFunction builds a CAST ScalarFunction from the Expression. func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) { + res, err := BuildCastFunctionWithCheck(ctx, expr, tp, false) + terror.Log(err) + return +} + +// BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any. +func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType, isExplicitCharset bool) (res Expression, err error) { argType := expr.GetType() // If source argument's nullable, then target type should be nullable if !mysql.HasNotNullFlag(argType.GetFlag()) { @@ -1945,13 +1953,12 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT case types.ETJson: fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} case types.ETString: - fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, isExplicitCharset} if expr.GetType().GetType() == mysql.TypeBit { tp.SetFlen((expr.GetType().GetFlen() + 7) / 8) } } f, err := fc.getFunction(ctx, []Expression{expr}) - terror.Log(err) res = &ScalarFunction{ FuncName: model.NewCIStr(ast.Cast), RetType: tp, @@ -1963,7 +1970,7 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT if tp.EvalType() != types.ETJson { res = FoldConstant(res) } - return res + return res, err } // WrapWithCastAsInt wraps `expr` with `cast` if the return type of expr is not diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index 2e5c5acbb08ef..80c692673f6b8 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -646,7 +646,7 @@ func TestCastFuncSig(t *testing.T) { tp := types.NewFieldType(mysql.TypeVarString) tp.SetCharset(charset.CharsetBin) args := []Expression{c.before} - stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp) + stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false) require.NoError(t, err) switch i { case 0: @@ -725,7 +725,7 @@ func TestCastFuncSig(t *testing.T) { tp := types.NewFieldType(mysql.TypeVarString) tp.SetFlen(c.flen) tp.SetCharset(charset.CharsetBin) - stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp) + stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false) require.NoError(t, err) switch i { case 0: @@ -1074,7 +1074,7 @@ func TestCastFuncSig(t *testing.T) { // null case args := []Expression{&Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}} row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(nil)}) - bf, err := newBaseBuiltinFunc(ctx, "", args, types.NewFieldType(mysql.TypeVarString)) + bf, err := newBaseBuiltinCastFunc4String(ctx, "", args, types.NewFieldType(mysql.TypeVarString), false) require.NoError(t, err) sig = &builtinCastRealAsStringSig{bf} sRes, isNull, err := sig.evalString(row.ToRow()) diff --git a/expression/collation.go b/expression/collation.go index e44a0c255bfbf..03975108b6f99 100644 --- a/expression/collation.go +++ b/expression/collation.go @@ -44,6 +44,8 @@ type collationInfo struct { charset string collation string + + isExplicitCharset bool } func (c *collationInfo) HasCoercibility() bool { @@ -76,6 +78,14 @@ func (c *collationInfo) CharsetAndCollation() (string, string) { return c.charset, c.collation } +func (c *collationInfo) IsExplicitCharset() bool { + return c.isExplicitCharset +} + +func (c *collationInfo) SetExplicitCharset(explicit bool) { + c.isExplicitCharset = explicit +} + // CollationInfo contains all interfaces about dealing with collation. type CollationInfo interface { // HasCoercibility returns if the Coercibility value is initialized. @@ -98,6 +108,12 @@ type CollationInfo interface { // SetCharsetAndCollation sets charset and collation. SetCharsetAndCollation(chs, coll string) + + // IsExplicitCharset return the charset is explicit set or not. + IsExplicitCharset() bool + + // SetExplicitCharset set the charset is explicit or not. + SetExplicitCharset(bool) } // Coercibility values are used to check whether the collation of one item can be coerced to @@ -246,9 +262,8 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, case ast.Cast: // We assume all the cast are implicit. ec = &ExprCollation{args[0].Coercibility(), args[0].Repertoire(), args[0].GetType().GetCharset(), args[0].GetType().GetCollate()} - // Non-string type cast to string type should use @@character_set_connection and @@collation_connection. - // String type cast to string type should keep its original charset and collation. It should not happen. - if retType == types.ETString && argTps[0] != types.ETString { + // Cast to string type should use @@character_set_connection and @@collation_connection. + if retType == types.ETString { ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo() } return ec, nil diff --git a/expression/integration_test.go b/expression/integration_test.go index 5f96b0ac5c164..4335ed35ef5ac 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -8078,3 +8078,26 @@ func TestTestIssue53580(t *testing.T) { ) then 1 else 2 end; `).Check(testkit.Rows()) } + +func TestCastAsStringExplicitCharSet(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + + tk.MustExec("CREATE TABLE `test` (" + + " `id` bigint(20) NOT NULL," + + " `update_user` varchar(32) DEFAULT NULL," + + " PRIMARY KEY (`id`) /*T![clustered_index] CLUSTERED */" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin") + tk.MustExec("insert into test values(1,'张三'), (2,'李四'), (3,'张三'), (4,'李四')") + tk.MustQuery("select id from test order by cast(update_user as char character set gbk) desc , id limit 3").Check(testkit.Rows("1", "3", "2")) + + tk.MustExec("drop table test") + tk.MustExec("create table test (`id` bigint NOT NULL," + + " `update_user` varchar(32) CHARACTER SET gbk COLLATE gbk_chinese_ci DEFAULT NULL," + + " PRIMARY KEY (`id`)" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin") + tk.MustExec("insert into test values(1,'张三'), (2,'李四'), (3,'张三'), (4,'李四')") + tk.MustQuery("select id from test order by cast(update_user as char) desc , id limit 3").Check(testkit.Rows("2", "4", "1")) +} diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 0cf8050d61230..d639747596a19 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -620,6 +620,16 @@ func (sf *ScalarFunction) SetRepertoire(r Repertoire) { sf.Function.SetRepertoire(r) } +// IsExplicitCharset return the charset is explicit set or not. +func (sf *ScalarFunction) IsExplicitCharset() bool { + return sf.Function.IsExplicitCharset() +} + +// SetExplicitCharset set the charset is explicit or not. +func (sf *ScalarFunction) SetExplicitCharset(explicit bool) { + sf.Function.SetExplicitCharset(explicit) +} + const emptyScalarFunctionSize = int64(unsafe.Sizeof(ScalarFunction{})) // MemoryUsage return the memory usage of ScalarFunction diff --git a/expression/util.go b/expression/util.go index e51697f19796a..75166f12b9e73 100644 --- a/expression/util.go +++ b/expression/util.go @@ -428,7 +428,8 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression } if substituted { flag := v.RetType.GetFlag() - e := BuildCastFunction(v.GetCtx(), newArg, v.RetType) + e, err := BuildCastFunctionWithCheck(v.GetCtx(), newArg, v.RetType, v.Function.IsExplicitCharset()) + terror.Log(err) e.SetCoercibility(v.Coercibility()) e.GetType().SetFlag(flag) return true, false, e diff --git a/expression/util_test.go b/expression/util_test.go index f462bdd9c8a0c..8a39bdedac213 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -590,6 +590,8 @@ func (m *MockExpr) Coercibility() Coercibility func (m *MockExpr) SetCoercibility(Coercibility) {} func (m *MockExpr) Repertoire() Repertoire { return UNICODE } func (m *MockExpr) SetRepertoire(Repertoire) {} +func (m *MockExpr) IsExplicitCharset() bool { return false } +func (m *MockExpr) SetExplicitCharset(bool) {} func (m *MockExpr) CharsetAndCollation() (string, string) { return "", "" diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 94c7467dac851..f2fc90a802ef6 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1209,7 +1209,11 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok return retNode, false } - castFunction := expression.BuildCastFunction(er.sctx, arg, v.Tp) + castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp, v.ExplicitCharSet) + if err != nil { + er.err = err + return retNode, false + } if v.Tp.EvalType() == types.ETString { castFunction.SetCoercibility(expression.CoercibilityImplicit) if v.Tp.GetCharset() == charset.CharsetASCII {