diff --git a/cmd/explaintest/r/collation_check_use_collation.result b/cmd/explaintest/r/collation_check_use_collation.result index 84f0f583ae796..30a766eecfda6 100644 --- a/cmd/explaintest/r/collation_check_use_collation.result +++ b/cmd/explaintest/r/collation_check_use_collation.result @@ -22,6 +22,62 @@ a select a as a_col from t where t.a = (select a collate utf8mb4_general_ci from t1); a_col a +drop table if exists t; +create table t(a enum('a', 'b'), b varchar(20)); +insert into t values ("a", "b"); +select * from t where a in (a); +a b +a b +drop table if exists t; +create table t(a enum('a', 'b') charset utf8mb4 collate utf8mb4_general_ci, b varchar(20)); +insert into t values ("B", "b"); +select * from t where 'B' collate utf8mb4_general_ci in (a); +a b +b b +select * from t where 'B' collate utf8mb4_bin in (a); +a b +select * from t where 'B' collate utf8mb4_bin in (a, b); +a b +select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); +a b +select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); +a b +b b +select * from t where 1 in (a); +a b +select * from t where 2 in (a); +a b +b b +select * from t where 1 in (a, 0); +a b +drop table if exists t; +create table t(a set('a', 'b'), b varchar(20)); +insert into t values ("a", "b"); +select * from t where a in (a); +a b +a b +drop table if exists t; +create table t(a set('a', 'b') charset utf8mb4 collate utf8mb4_general_ci, b varchar(20)); +insert into t values ("B", "b"); +select * from t where 'B' collate utf8mb4_general_ci in (a); +a b +b b +select * from t where 'B' collate utf8mb4_bin in (a); +a b +select * from t where 'B' collate utf8mb4_bin in (a, b); +a b +select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); +a b +select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); +a b +b b +select * from t where 1 in (a); +a b +select * from t where 2 in (a); +a b +b b +select * from t where 1 in (a, 0); +a b drop table if exists tbl_2; create table tbl_2 ( col_20 bigint not null , 
col_21 smallint not null , col_22 decimal(24,10) default null , col_23 tinyint default 71 not null , col_24 bigint not null , col_25 tinyint default 18 , col_26 varchar(330) collate utf8_bin not null , col_27 char(77) collate utf8mb4_unicode_ci , col_28 char(46) collate utf8_general_ci not null , col_29 smallint unsigned not null , primary key idx_13 ( col_27(5) ) , key idx_14 ( col_24 ) , unique key idx_15 ( col_23,col_21,col_28,col_29,col_24 ) ) collate utf8_bin ; insert ignore into tbl_2 values ( 5888267793391993829,5371,94.63,-109,5728076076919247337,89,'WUicqUTgdGJcjbC','SapBPqczTWWSN','xUSwH',49462 ); diff --git a/cmd/explaintest/t/collation_check_use_collation.test b/cmd/explaintest/t/collation_check_use_collation.test index 0fd8f373f8e76..97ac87f681cf9 100644 --- a/cmd/explaintest/t/collation_check_use_collation.test +++ b/cmd/explaintest/t/collation_check_use_collation.test @@ -19,6 +19,43 @@ select a as a_col from t where t.a <= all (select a collate utf8mb4_general_ci f select a as a_col from t where t.a <= any (select a collate utf8mb4_general_ci from t1); select a as a_col from t where t.a = (select a collate utf8mb4_general_ci from t1); +## Check rewrite in expression + +# enum part +drop table if exists t; +create table t(a enum('a', 'b'), b varchar(20)); +insert into t values ("a", "b"); +select * from t where a in (a); +drop table if exists t; +create table t(a enum('a', 'b') charset utf8mb4 collate utf8mb4_general_ci, b varchar(20)); +insert into t values ("B", "b"); +select * from t where 'B' collate utf8mb4_general_ci in (a); +select * from t where 'B' collate utf8mb4_bin in (a); +select * from t where 'B' collate utf8mb4_bin in (a, b); +select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); +select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); +select * from t where 1 in (a); +select * from t where 2 in (a); +select * from t where 1 in (a, 0); + +# set part +drop table if exists t; +create table t(a set('a', 'b'), b 
varchar(20)); +insert into t values ("a", "b"); +select * from t where a in (a); +drop table if exists t; +create table t(a set('a', 'b') charset utf8mb4 collate utf8mb4_general_ci, b varchar(20)); +insert into t values ("B", "b"); +select * from t where 'B' collate utf8mb4_general_ci in (a); +select * from t where 'B' collate utf8mb4_bin in (a); +select * from t where 'B' collate utf8mb4_bin in (a, b); +select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); +select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); +select * from t where 1 in (a); +select * from t where 2 in (a); +select * from t where 1 in (a, 0); + +# check build range drop table if exists tbl_2; create table tbl_2 ( col_20 bigint not null , col_21 smallint not null , col_22 decimal(24,10) default null , col_23 tinyint default 71 not null , col_24 bigint not null , col_25 tinyint default 18 , col_26 varchar(330) collate utf8_bin not null , col_27 char(77) collate utf8mb4_unicode_ci , col_28 char(46) collate utf8_general_ci not null , col_29 smallint unsigned not null , primary key idx_13 ( col_27(5) ) , key idx_14 ( col_24 ) , unique key idx_15 ( col_23,col_21,col_28,col_29,col_24 ) ) collate utf8_bin ; insert ignore into tbl_2 values ( 5888267793391993829,5371,94.63,-109,5728076076919247337,89,'WUicqUTgdGJcjbC','SapBPqczTWWSN','xUSwH',49462 ); diff --git a/expression/integration_serial_test.go b/expression/integration_serial_test.go index cb9dd3bed32a5..2ac3e40830ec7 100644 --- a/expression/integration_serial_test.go +++ b/expression/integration_serial_test.go @@ -180,11 +180,12 @@ func TestCollationBasic(t *testing.T) { tk.MustExec("create table t(a char(10))") tk.MustExec("insert into t values ('a')") tk.MustQuery("select * from t where a in ('b' collate utf8mb4_general_ci, 'A', 3)").Check(testkit.Rows("a")) - // These test cases may not the same as MySQL, but it's more reasonable. 
- tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci));").Check(testkit.Rows("1")) - tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin));").Check(testkit.Rows("0")) - tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci), ('b', 'b'));").Check(testkit.Rows("1")) - tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin), ('b', 'b'));").Check(testkit.Rows("0")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(`COL2` tinyint(16) DEFAULT NULL);") + tk.MustExec("insert into t values(0);") + tk.MustQuery("select * from t WHERE COL2 IN (0xfc);").Check(testkit.Rows()) + tk.MustQuery("select * from t WHERE COL2 = 0xfc;").Check(testkit.Rows()) } func TestWeightString(t *testing.T) { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 08e0262613cb9..ab6f293a7bf63 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1522,8 +1522,8 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field } // deriveCollationForIn derives collation for in expression. -func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkLen int, args []expression.Expression) []*expression.ExprCollation { - coll := make([]*expression.ExprCollation, 0, colLen) +// We don't handle the cases if the element is a tuple, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)). +func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkLen int, args []expression.Expression) *expression.ExprCollation { if colLen == 1 { // a in (x, y, z) => coll[0] coll2, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...) 
@@ -1531,46 +1531,39 @@ func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkL if er.err != nil { return nil } - coll = append(coll, coll2) - } else { - // (a, b, c) in ((x1, x2, x3), (y1, y2, y3), (z1, z2, z3)) => coll[0], coll[1], coll[2] - for i := 0; i < colLen; i++ { - args := make([]expression.Expression, 0, elemCnt) - for j := stkLen - elemCnt - 1; j < stkLen; j++ { - rowFunc, _ := er.ctxStack[j].(*expression.ScalarFunction) - args = append(args, rowFunc.GetArgs()[i]) - } - coll2, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...) - er.err = err - if er.err != nil { - return nil - } - coll = append(coll, coll2) - } + return coll2 } - return coll + return nil } // castCollationForIn casts collation info for arguments in the `in clause` to make sure the used collation is correct after we // rewrite it to equal expression. -func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll []*expression.ExprCollation) { +func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll *expression.ExprCollation) { + // We don't handle the cases if the element is a tuple, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)). + if colLen != 1 { + return + } for i := stkLen - elemCnt; i < stkLen; i++ { - if colLen == 1 && er.ctxStack[i].GetType().EvalType() == types.ETString { + if er.ctxStack[i].GetType().EvalType() == types.ETString { + rowFunc, ok := er.ctxStack[i].(*expression.ScalarFunction) + if ok && rowFunc.FuncName.String() == ast.RowFunc { + continue + } + // Don't convert it if its charset is binary, so that we don't convert 0x12 to a string. 
+ if er.ctxStack[i].GetType().Collate == coll.Collation { + continue + } tp := er.ctxStack[i].GetType().Clone() - tp.Charset, tp.Collate = coll[0].Charset, coll[0].Collation - er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) - er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit) - } else { - rowFunc, _ := er.ctxStack[i].(*expression.ScalarFunction) - for j := 0; j < colLen; j++ { - if er.ctxStack[i].GetType().EvalType() != types.ETString { + if er.ctxStack[i].GetType().Hybrid() { + if expression.GetAccurateCmpType(er.ctxStack[stkLen-elemCnt-1], er.ctxStack[i]) == types.ETString { + tp = types.NewFieldType(mysql.TypeVarString) + } else { continue } - tp := rowFunc.GetArgs()[j].GetType().Clone() - tp.Charset, tp.Collate = coll[j].Charset, coll[j].Collation - rowFunc.GetArgs()[j] = expression.BuildCastFunction(er.sctx, rowFunc.GetArgs()[j], tp) - rowFunc.GetArgs()[j].SetCoercibility(expression.CoercibilityExplicit) } + tp.Charset, tp.Collate = coll.Charset, coll.Collation + er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) + er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit) } } }