Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: let cast function supports explicit set charset (#55724) #56912

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion expression/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1452,7 +1452,7 @@ func genVecBuiltinFuncBenchCase(ctx sessionctx.Context, funcName string, testCas
case types.ETJson:
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ETString:
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, false}
}
baseFunc, err = fc.getFunction(ctx, cols)
} else if funcName == ast.GetVar {
Expand Down
30 changes: 30 additions & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,36 @@
}
}

func newBaseBuiltinCastFunc4String(ctx sessionctx.Context, funcName string, args []Expression, tp *types.FieldType, isExplicitCharset bool) (baseBuiltinFunc, error) {
var bf baseBuiltinFunc
var err error
if isExplicitCharset {
bf = baseBuiltinFunc{
bufAllocator: newLocalColumnPool(),
childrenVectorizedOnce: new(sync.Once),

args: args,
ctx: ctx,
tp: tp,
}
bf.SetCharsetAndCollation(tp.GetCharset(), tp.GetCollate())
bf.setCollator(collate.GetCollator(tp.GetCollate()))
bf.SetCoercibility(CoercibilityExplicit)
bf.SetExplicitCharset(true)
if tp.GetCharset() == charset.CharsetASCII {
bf.SetRepertoire(ASCII)
} else {
bf.SetRepertoire(UNICODE)
}
} else {
bf, err = newBaseBuiltinFunc(ctx, funcName, args, tp)
if err != nil {
return baseBuiltinFunc{}, err
}

Check warning on line 446 in expression/builtin.go

View check run for this annotation

Codecov / codecov/patch

expression/builtin.go#L445-L446

Added lines #L445 - L446 were not covered by tests
}
return bf, nil
}

// vecBuiltinFunc contains all vectorized methods for a builtin function.
type vecBuiltinFunc interface {
// vectorized returns if this builtin function itself supports vectorized evaluation.
Expand Down
11 changes: 6 additions & 5 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,15 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx sessionctx.Context, args []
type castAsStringFunctionClass struct {
baseFunctionClass

tp *types.FieldType
tp *types.FieldType
isExplicitCharset bool
}

func (c *castAsStringFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (sig builtinFunc, err error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp)
bf, err := newBaseBuiltinCastFunc4String(ctx, c.funcName, args, c.tp, c.isExplicitCharset)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -2080,13 +2081,13 @@ func BuildCastCollationFunction(ctx sessionctx.Context, expr Expression, ec *Exp

// BuildCastFunction builds a CAST ScalarFunction from the Expression.
func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
res, err := BuildCastFunctionWithCheck(ctx, expr, tp)
res, err := BuildCastFunctionWithCheck(ctx, expr, tp, false)
terror.Log(err)
return
}

// BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any.
func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression, err error) {
func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType, isExplicitCharset bool) (res Expression, err error) {
argType := expr.GetType()
// If source argument's nullable, then target type should be nullable
if !mysql.HasNotNullFlag(argType.GetFlag()) {
Expand All @@ -2112,7 +2113,7 @@ func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *typ
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
}
case types.ETString:
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, isExplicitCharset}
if expr.GetType().GetType() == mysql.TypeBit {
tp.SetFlen((expr.GetType().GetFlen() + 7) / 8)
}
Expand Down
12 changes: 5 additions & 7 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ func TestCastFuncSig(t *testing.T) {
tp := types.NewFieldType(mysql.TypeVarString)
tp.SetCharset(charset.CharsetBin)
args := []Expression{c.before}
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp)
stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false)
require.NoError(t, err)
switch i {
case 0:
Expand Down Expand Up @@ -732,7 +732,7 @@ func TestCastFuncSig(t *testing.T) {
tp := types.NewFieldType(mysql.TypeVarString)
tp.SetFlen(c.flen)
tp.SetCharset(charset.CharsetBin)
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp)
stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false)
require.NoError(t, err)
switch i {
case 0:
Expand Down Expand Up @@ -1083,7 +1083,7 @@ func TestCastFuncSig(t *testing.T) {
// null case
args := []Expression{&Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}}
row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(nil)})
bf, err := newBaseBuiltinFunc(ctx, "", args, types.NewFieldType(mysql.TypeVarString))
bf, err := newBaseBuiltinCastFunc4String(ctx, "", args, types.NewFieldType(mysql.TypeVarString), false)
require.NoError(t, err)
sig = &builtinCastRealAsStringSig{bf}
sRes, isNull, err := sig.evalString(row.ToRow())
Expand Down Expand Up @@ -1677,10 +1677,8 @@ func TestCastArrayFunc(t *testing.T) {
},
}
for _, tt := range tbl {
f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp)
if tt.buildFuncSuccess {
require.NoError(t, err, tt.input)
} else {
f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp, false)
if !tt.buildFuncSuccess {
require.Error(t, err, tt.input)
continue
}
Expand Down
21 changes: 18 additions & 3 deletions expression/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ type collationInfo struct {

charset string
collation string

isExplicitCharset bool
}

func (c *collationInfo) HasCoercibility() bool {
Expand Down Expand Up @@ -76,6 +78,14 @@ func (c *collationInfo) CharsetAndCollation() (string, string) {
return c.charset, c.collation
}

func (c *collationInfo) IsExplicitCharset() bool {
return c.isExplicitCharset
}

func (c *collationInfo) SetExplicitCharset(explicit bool) {
c.isExplicitCharset = explicit
}

// CollationInfo contains all interfaces about dealing with collation.
type CollationInfo interface {
// HasCoercibility returns if the Coercibility value is initialized.
Expand All @@ -98,6 +108,12 @@ type CollationInfo interface {

// SetCharsetAndCollation sets charset and collation.
SetCharsetAndCollation(chs, coll string)

// IsExplicitCharset return the charset is explicit set or not.
IsExplicitCharset() bool

// SetExplicitCharset set the charset is explicit or not.
SetExplicitCharset(bool)
}

// Coercibility values are used to check whether the collation of one item can be coerced to
Expand Down Expand Up @@ -246,9 +262,8 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression,
case ast.Cast:
// We assume all the cast are implicit.
ec = &ExprCollation{args[0].Coercibility(), args[0].Repertoire(), args[0].GetType().GetCharset(), args[0].GetType().GetCollate()}
// Non-string type cast to string type should use @@character_set_connection and @@collation_connection.
// String type cast to string type should keep its original charset and collation. It should not happen.
if retType == types.ETString && argTps[0] != types.ETString {
// Cast to string type should use @@character_set_connection and @@collation_connection.
if retType == types.ETString {
ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo()
}
return ec, nil
Expand Down
10 changes: 10 additions & 0 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,16 @@
sf.Function.SetRepertoire(r)
}

// IsExplicitCharset return the charset is explicit set or not.
func (sf *ScalarFunction) IsExplicitCharset() bool {
return sf.Function.IsExplicitCharset()

Check warning on line 631 in expression/scalar_function.go

View check run for this annotation

Codecov / codecov/patch

expression/scalar_function.go#L630-L631

Added lines #L630 - L631 were not covered by tests
}

// SetExplicitCharset set the charset is explicit or not.
func (sf *ScalarFunction) SetExplicitCharset(explicit bool) {
sf.Function.SetExplicitCharset(explicit)

Check warning on line 636 in expression/scalar_function.go

View check run for this annotation

Codecov / codecov/patch

expression/scalar_function.go#L635-L636

Added lines #L635 - L636 were not covered by tests
}

const emptyScalarFunctionSize = int64(unsafe.Sizeof(ScalarFunction{}))

// MemoryUsage return the memory usage of ScalarFunction
Expand Down
3 changes: 2 additions & 1 deletion expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,8 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression
}
if substituted {
flag := v.RetType.GetFlag()
e := BuildCastFunction(v.GetCtx(), newArg, v.RetType)
e, err := BuildCastFunctionWithCheck(v.GetCtx(), newArg, v.RetType, v.Function.IsExplicitCharset())
terror.Log(err)
e.SetCoercibility(v.Coercibility())
e.GetType().SetFlag(flag)
return true, false, e
Expand Down
2 changes: 2 additions & 0 deletions expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,8 @@ func (m *MockExpr) Coercibility() Coercibility
func (m *MockExpr) SetCoercibility(Coercibility) {}
func (m *MockExpr) Repertoire() Repertoire { return UNICODE }
func (m *MockExpr) SetRepertoire(Repertoire) {}
func (m *MockExpr) IsExplicitCharset() bool { return false }
func (m *MockExpr) SetExplicitCharset(bool) {}

func (m *MockExpr) CharsetAndCollation() (string, string) {
return "", ""
Expand Down
2 changes: 1 addition & 1 deletion planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
return retNode, false
}

castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp)
castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp, v.ExplicitCharSet)
if err != nil {
er.err = err
return retNode, false
Expand Down
30 changes: 30 additions & 0 deletions tests/integrationtest/r/expression/cast.result
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,33 @@ select 1.194192591e9 > t0.c0 from t0;
select 1.194192591e9 < t0.c0 from t0;
1.194192591e9 < t0.c0
0
drop table if exists test;
CREATE TABLE `test` (
`id` bigint(20) NOT NULL,
`update_user` varchar(32) DEFAULT NULL,
PRIMARY KEY (`id`) /*T![clustered_index] CLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
insert into test values(1,'张三');
insert into test values(2,'李四');
insert into test values(3,'张三');
insert into test values(4,'李四');
select * from test order by cast(update_user as char character set gbk) desc , id limit 3;
id update_user
1 张三
3 张三
2 李四
drop table test;
CREATE TABLE `test` (
`id` bigint NOT NULL,
`update_user` varchar(32) CHARACTER SET gbk COLLATE gbk_chinese_ci DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
insert into test values(1,'张三');
insert into test values(2,'李四');
insert into test values(3,'张三');
insert into test values(4,'李四');
select * from test order by cast(update_user as char) desc , id limit 3;
id update_user
2 李四
4 李四
1 张三
26 changes: 26 additions & 0 deletions tests/integrationtest/t/expression/cast.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,29 @@ select t0.c0 > 1.194192591e9 from t0;
select t0.c0 < 1.194192591e9 from t0;
select 1.194192591e9 > t0.c0 from t0;
select 1.194192591e9 < t0.c0 from t0;

# TestCastAsStringExplicitCharSet
drop table if exists test;
CREATE TABLE `test` (
`id` bigint(20) NOT NULL,
`update_user` varchar(32) DEFAULT NULL,
PRIMARY KEY (`id`) /*T![clustered_index] CLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
insert into test values(1,'张三');
insert into test values(2,'李四');
insert into test values(3,'张三');
insert into test values(4,'李四');
select * from test order by cast(update_user as char character set gbk) desc , id limit 3;

drop table test;
CREATE TABLE `test` (
`id` bigint NOT NULL,
`update_user` varchar(32) CHARACTER SET gbk COLLATE gbk_chinese_ci DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
insert into test values(1,'张三');
insert into test values(2,'李四');
insert into test values(3,'张三');
insert into test values(4,'李四');
select * from test order by cast(update_user as char) desc , id limit 3;