From 2e262589109e1cd4e17d1a2e95867b36a88949ac Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 4 Jan 2022 14:17:07 +0800 Subject: [PATCH 1/7] done Signed-off-by: wjhuang2016 --- planner/core/expression_rewriter.go | 50 +++++++++-------------------- 1 file changed, 16 insertions(+), 34 deletions(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 08e0262613cb9..f933c2e452e73 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1522,8 +1522,8 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field } // deriveCollationForIn derives collation for in expression. -func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkLen int, args []expression.Expression) []*expression.ExprCollation { - coll := make([]*expression.ExprCollation, 0, colLen) +// We don't handle the case where each element is a tuple, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)). +func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkLen int, args []expression.Expression) *expression.ExprCollation { if colLen == 1 { // a in (x, y, z) => coll[0] coll2, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...) @@ -1531,46 +1531,28 @@ func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkL if er.err != nil { return nil } - coll = append(coll, coll2) - } else { - // (a, b, c) in ((x1, x2, x3), (y1, y2, y3), (z1, z2, z3)) => coll[0], coll[1], coll[2] - for i := 0; i < colLen; i++ { - args := make([]expression.Expression, 0, elemCnt) - for j := stkLen - elemCnt - 1; j < stkLen; j++ { - rowFunc, _ := er.ctxStack[j].(*expression.ScalarFunction) - args = append(args, rowFunc.GetArgs()[i]) - } - coll2, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...) 
- er.err = err - if er.err != nil { - return nil - } - coll = append(coll, coll2) - } + return coll2 } - return coll + return nil } // castCollationForIn casts collation info for arguments in the `in clause` to make sure the used collation is correct after we // rewrite it to equal expression. -func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll []*expression.ExprCollation) { - for i := stkLen - elemCnt; i < stkLen; i++ { - if colLen == 1 && er.ctxStack[i].GetType().EvalType() == types.ETString { +func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll *expression.ExprCollation) { + // We don't handle the cases if the element is a tuple, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)). + if colLen != 1 { + return + } + for i := stkLen - elemCnt - 1; i < stkLen; i++ { + if er.ctxStack[i].GetType().EvalType() == types.ETString { + rowFunc, ok := er.ctxStack[i].(*expression.ScalarFunction) + if ok && rowFunc.FuncName.String() == ast.RowFunc { + continue + } tp := er.ctxStack[i].GetType().Clone() - tp.Charset, tp.Collate = coll[0].Charset, coll[0].Collation + tp.Charset, tp.Collate = coll.Charset, coll.Collation er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit) - } else { - rowFunc, _ := er.ctxStack[i].(*expression.ScalarFunction) - for j := 0; j < colLen; j++ { - if er.ctxStack[i].GetType().EvalType() != types.ETString { - continue - } - tp := rowFunc.GetArgs()[j].GetType().Clone() - tp.Charset, tp.Collate = coll[j].Charset, coll[j].Collation - rowFunc.GetArgs()[j] = expression.BuildCastFunction(er.sctx, rowFunc.GetArgs()[j], tp) - rowFunc.GetArgs()[j].SetCoercibility(expression.CoercibilityExplicit) - } } } } From 473012b1b89466e4b40761941ad9ac6f8267455b Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 4 Jan 2022 14:30:45 +0800 Subject: [PATCH 2/7] fix test Signed-off-by: wjhuang2016 --- 
expression/integration_serial_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/expression/integration_serial_test.go b/expression/integration_serial_test.go index 3077e2f1b33a7..3be189d90246a 100644 --- a/expression/integration_serial_test.go +++ b/expression/integration_serial_test.go @@ -180,11 +180,6 @@ func TestCollationBasic(t *testing.T) { tk.MustExec("create table t(a char(10))") tk.MustExec("insert into t values ('a')") tk.MustQuery("select * from t where a in ('b' collate utf8mb4_general_ci, 'A', 3)").Check(testkit.Rows("a")) - // These test cases may not the same as MySQL, but it's more reasonable. - tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci));").Check(testkit.Rows("1")) - tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin));").Check(testkit.Rows("0")) - tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci), ('b', 'b'));").Check(testkit.Rows("1")) - tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin), ('b', 'b'));").Check(testkit.Rows("0")) } func TestWeightString(t *testing.T) { From 8a47ce43c01e57bf498cdfd5e7837ce55d71db85 Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 4 Jan 2022 15:56:10 +0800 Subject: [PATCH 3/7] reset Signed-off-by: wjhuang2016 --- planner/core/expression_rewriter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index f933c2e452e73..37134ba71f19d 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1543,7 +1543,7 @@ func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen if colLen != 1 { return } - for i := stkLen - elemCnt - 1; i < stkLen; i++ { + for i := stkLen - elemCnt; i < stkLen; i++ { if er.ctxStack[i].GetType().EvalType() == types.ETString { rowFunc, ok := 
er.ctxStack[i].(*expression.ScalarFunction) if ok && rowFunc.FuncName.String() == ast.RowFunc { From 65a0c8e0d61ba48c21c48b35424f7b4c771ae4f2 Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Wed, 5 Jan 2022 12:19:42 +0800 Subject: [PATCH 4/7] fix Signed-off-by: wjhuang2016 --- expression/integration_serial_test.go | 6 ++++++ planner/core/expression_rewriter.go | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/expression/integration_serial_test.go b/expression/integration_serial_test.go index 3be189d90246a..bf47afa87a45b 100644 --- a/expression/integration_serial_test.go +++ b/expression/integration_serial_test.go @@ -180,6 +180,12 @@ func TestCollationBasic(t *testing.T) { tk.MustExec("create table t(a char(10))") tk.MustExec("insert into t values ('a')") tk.MustQuery("select * from t where a in ('b' collate utf8mb4_general_ci, 'A', 3)").Check(testkit.Rows("a")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(`COL2` tinyint(16) DEFAULT NULL);") + tk.MustExec("insert into t values(0);") + tk.MustQuery("select * from t WHERE COL2 IN (0xfc);").Check(testkit.Rows()) + tk.MustQuery("select * from t WHERE COL2 = 0xfc;").Check(testkit.Rows()) } func TestWeightString(t *testing.T) { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 37134ba71f19d..2c93d14268124 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1544,7 +1544,8 @@ func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen return } for i := stkLen - elemCnt; i < stkLen; i++ { - if er.ctxStack[i].GetType().EvalType() == types.ETString { + // Don't convert it if it's charset is binary. So that we don't convert 0x12 to a string. 
+ if er.ctxStack[i].GetType().EvalType() == types.ETString && er.ctxStack[i].GetType().Charset != charset.CharsetBin { rowFunc, ok := er.ctxStack[i].(*expression.ScalarFunction) if ok && rowFunc.FuncName.String() == ast.RowFunc { continue From 92a549d629bd51a6a1ab5d674dd0ac088abfda54 Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Wed, 5 Jan 2022 12:26:05 +0800 Subject: [PATCH 5/7] refine Signed-off-by: wjhuang2016 --- planner/core/expression_rewriter.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 2c93d14268124..fffc66a6fafe7 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1544,12 +1544,15 @@ func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen return } for i := stkLen - elemCnt; i < stkLen; i++ { - // Don't convert it if it's charset is binary. So that we don't convert 0x12 to a string. - if er.ctxStack[i].GetType().EvalType() == types.ETString && er.ctxStack[i].GetType().Charset != charset.CharsetBin { + if er.ctxStack[i].GetType().EvalType() == types.ETString { rowFunc, ok := er.ctxStack[i].(*expression.ScalarFunction) if ok && rowFunc.FuncName.String() == ast.RowFunc { continue } + // Skip the cast if it already has the derived collation, so that we don't convert e.g. 0x12 to a string. 
+ if er.ctxStack[i].GetType().Collate == coll.Collation { + continue + } tp := er.ctxStack[i].GetType().Clone() tp.Charset, tp.Collate = coll.Charset, coll.Collation er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) From d256c13c5cb40b0b6962d883930298f2df0e46b6 Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Thu, 6 Jan 2022 15:02:43 +0800 Subject: [PATCH 6/7] done Signed-off-by: wjhuang2016 --- .../r/collation_check_use_collation.result | 42 +++++++++++++++++++ .../t/collation_check_use_collation.test | 30 +++++++++++++ planner/core/expression_rewriter.go | 3 ++ 3 files changed, 75 insertions(+) diff --git a/cmd/explaintest/r/collation_check_use_collation.result b/cmd/explaintest/r/collation_check_use_collation.result index ffd787a4cef43..ff5e1e6a54dca 100644 --- a/cmd/explaintest/r/collation_check_use_collation.result +++ b/cmd/explaintest/r/collation_check_use_collation.result @@ -22,4 +22,46 @@ a select a as a_col from t where t.a = (select a collate utf8mb4_general_ci from t1); a_col a +drop table if exists t; +create table t(a enum('a', 'b'), b varchar(20)); +insert into t values ("a", "b"); +select * from t where a in (a); +a b +a b +drop table if exists t; +create table t(a enum('a', 'b') charset utf8mb4 collate utf8mb4_general_ci, b varchar(20)); +insert into t values ("B", "b"); +select * from t where 'B' collate utf8mb4_general_ci in (a); +a b +b b +select * from t where 'B' collate utf8mb4_bin in (a); +a b +select * from t where 'B' collate utf8mb4_bin in (a, b); +a b +select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); +a b +select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); +a b +b b +drop table if exists t; +create table t(a set('a', 'b'), b varchar(20)); +insert into t values ("a", "b"); +select * from t where a in (a); +a b +a b +drop table if exists t; +create table t(a set('a', 'b') charset utf8mb4 collate utf8mb4_general_ci, b varchar(20)); +insert into t values ("B", "b"); +select * from t where 'B' 
collate utf8mb4_general_ci in (a); +a b +b b +select * from t where 'B' collate utf8mb4_bin in (a); +a b +select * from t where 'B' collate utf8mb4_bin in (a, b); +a b +select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); +a b +select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); +a b +b b use test diff --git a/cmd/explaintest/t/collation_check_use_collation.test b/cmd/explaintest/t/collation_check_use_collation.test index 67e75f32e38f9..caf616ed554d6 100644 --- a/cmd/explaintest/t/collation_check_use_collation.test +++ b/cmd/explaintest/t/collation_check_use_collation.test @@ -19,5 +19,35 @@ select a as a_col from t where t.a <= all (select a collate utf8mb4_general_ci f select a as a_col from t where t.a <= any (select a collate utf8mb4_general_ci from t1); select a as a_col from t where t.a = (select a collate utf8mb4_general_ci from t1); +## Check rewrite in expression + +# enum part +drop table if exists t; +create table t(a enum('a', 'b'), b varchar(20)); +insert into t values ("a", "b"); +select * from t where a in (a); +drop table if exists t; +create table t(a enum('a', 'b') charset utf8mb4 collate utf8mb4_general_ci, b varchar(20)); +insert into t values ("B", "b"); +select * from t where 'B' collate utf8mb4_general_ci in (a); +select * from t where 'B' collate utf8mb4_bin in (a); +select * from t where 'B' collate utf8mb4_bin in (a, b); +select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); +select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); + +# set part +drop table if exists t; +create table t(a set('a', 'b'), b varchar(20)); +insert into t values ("a", "b"); +select * from t where a in (a); +drop table if exists t; +create table t(a set('a', 'b') charset utf8mb4 collate utf8mb4_general_ci, b varchar(20)); +insert into t values ("B", "b"); +select * from t where 'B' collate utf8mb4_general_ci in (a); +select * from t where 'B' collate utf8mb4_bin in (a); +select * from t where 'B' collate utf8mb4_bin in (a, b); 
+select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); +select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); + # cleanup environment use test diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index fffc66a6fafe7..c67d2f54c9faf 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1554,6 +1554,9 @@ func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen continue } tp := er.ctxStack[i].GetType().Clone() + if er.ctxStack[i].GetType().Hybrid() && expression.GetAccurateCmpType(er.ctxStack[stkLen-elemCnt-1], er.ctxStack[i]) == types.ETString { + tp = types.NewFieldType(mysql.TypeVarString) + } tp.Charset, tp.Collate = coll.Charset, coll.Collation er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit) From 5bf3288699984ab09b5346e644604bbbc5b9f43e Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Thu, 6 Jan 2022 15:13:38 +0800 Subject: [PATCH 7/7] done Signed-off-by: wjhuang2016 --- .../r/collation_check_use_collation.result | 14 ++++++++++++++ .../t/collation_check_use_collation.test | 6 ++++++ planner/core/expression_rewriter.go | 8 ++++++-- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/cmd/explaintest/r/collation_check_use_collation.result b/cmd/explaintest/r/collation_check_use_collation.result index ad1fd3da5ace6..30a766eecfda6 100644 --- a/cmd/explaintest/r/collation_check_use_collation.result +++ b/cmd/explaintest/r/collation_check_use_collation.result @@ -43,6 +43,13 @@ a b select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); a b b b +select * from t where 1 in (a); +a b +select * from t where 2 in (a); +a b +b b +select * from t where 1 in (a, 0); +a b drop table if exists t; create table t(a set('a', 'b'), b varchar(20)); insert into t values ("a", "b"); @@ -64,6 +71,13 @@ a b select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); 
a b b b +select * from t where 1 in (a); +a b +select * from t where 2 in (a); +a b +b b +select * from t where 1 in (a, 0); +a b drop table if exists tbl_2; create table tbl_2 ( col_20 bigint not null , col_21 smallint not null , col_22 decimal(24,10) default null , col_23 tinyint default 71 not null , col_24 bigint not null , col_25 tinyint default 18 , col_26 varchar(330) collate utf8_bin not null , col_27 char(77) collate utf8mb4_unicode_ci , col_28 char(46) collate utf8_general_ci not null , col_29 smallint unsigned not null , primary key idx_13 ( col_27(5) ) , key idx_14 ( col_24 ) , unique key idx_15 ( col_23,col_21,col_28,col_29,col_24 ) ) collate utf8_bin ; insert ignore into tbl_2 values ( 5888267793391993829,5371,94.63,-109,5728076076919247337,89,'WUicqUTgdGJcjbC','SapBPqczTWWSN','xUSwH',49462 ); diff --git a/cmd/explaintest/t/collation_check_use_collation.test b/cmd/explaintest/t/collation_check_use_collation.test index 1ff4601001791..97ac87f681cf9 100644 --- a/cmd/explaintest/t/collation_check_use_collation.test +++ b/cmd/explaintest/t/collation_check_use_collation.test @@ -34,6 +34,9 @@ select * from t where 'B' collate utf8mb4_bin in (a); select * from t where 'B' collate utf8mb4_bin in (a, b); select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); +select * from t where 1 in (a); +select * from t where 2 in (a); +select * from t where 1 in (a, 0); # set part drop table if exists t; @@ -48,6 +51,9 @@ select * from t where 'B' collate utf8mb4_bin in (a); select * from t where 'B' collate utf8mb4_bin in (a, b); select * from t where 'B' collate utf8mb4_bin in (a, "a", 1); select * from t where 'B' collate utf8mb4_bin in (a, "B", 1); +select * from t where 1 in (a); +select * from t where 2 in (a); +select * from t where 1 in (a, 0); # check build range drop table if exists tbl_2; diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 
c67d2f54c9faf..ab6f293a7bf63 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1554,8 +1554,12 @@ func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen continue } tp := er.ctxStack[i].GetType().Clone() - if er.ctxStack[i].GetType().Hybrid() && expression.GetAccurateCmpType(er.ctxStack[stkLen-elemCnt-1], er.ctxStack[i]) == types.ETString { - tp = types.NewFieldType(mysql.TypeVarString) + if er.ctxStack[i].GetType().Hybrid() { + if expression.GetAccurateCmpType(er.ctxStack[stkLen-elemCnt-1], er.ctxStack[i]) == types.ETString { + tp = types.NewFieldType(mysql.TypeVarString) + } else { + continue + } } tp.Charset, tp.Collate = coll.Charset, coll.Collation er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp)