diff --git a/cmd/explaintest/r/subquery.result b/cmd/explaintest/r/subquery.result index 2bf71b91088cb..05e12ff9f518b 100644 --- a/cmd/explaintest/r/subquery.result +++ b/cmd/explaintest/r/subquery.result @@ -46,3 +46,31 @@ create table t1(a int(11)); create table t2(a decimal(40,20) unsigned, b decimal(40,20)); select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b); x +drop table if exists stu; +drop table if exists exam; +create table stu(id int, name varchar(100)); +insert into stu values(1, null); +create table exam(stu_id int, course varchar(100), grade int); +insert into exam values(1, 'math', 100); +set names utf8 collate utf8_general_ci; +explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +id estRows task access object operator info +Apply 10000.00 root CARTESIAN anti semi join, other cond:eq(test.stu.name, Column#8) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo +└─Projection(Probe) 10.00 root guo->Column#8 + └─TableReader 10.00 root data:Selection + └─Selection 10.00 cop[tikv] eq(test.exam.stu_id, test.stu.id) + └─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo +select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +id name +set names utf8mb4; +explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +id estRows task access object operator info +HashJoin 8000.00 root anti semi join, equal:[eq(test.stu.id, test.exam.stu_id)], other cond:eq(test.stu.name, "guo") +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo +select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +id name diff --git a/cmd/explaintest/t/subquery.test b/cmd/explaintest/t/subquery.test index 6a3aa13e7e95a..5127c0e4260fa 100644 --- a/cmd/explaintest/t/subquery.test +++ b/cmd/explaintest/t/subquery.test @@ -20,3 +20,16 @@ drop table if exists t1, t2; create table t1(a int(11)); create table t2(a decimal(40,20) unsigned, b decimal(40,20)); select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b); + +drop table if exists stu; +drop table if exists exam; +create table stu(id int, name varchar(100)); +insert into stu values(1, null); +create table exam(stu_id int, course varchar(100), grade int); +insert into exam values(1, 'math', 100); +set names utf8 collate utf8_general_ci; +explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +set names utf8mb4; +explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); diff --git a/expression/util.go b/expression/util.go index d13f7f1e446e3..852cfc076dd45 100644 --- a/expression/util.go +++ b/expression/util.go @@ -181,7 +181,8 @@ func extractColumnSet(expr Expression, set *intsets.Sparse) { } } -func setExprColumnInOperand(expr Expression) Expression { +// SetExprColumnInOperand is used to set columns in expr as InOperand. +func SetExprColumnInOperand(expr Expression) Expression { switch v := expr.(type) { case *Column: col := v.Clone().(*Column) @@ -190,7 +191,7 @@ func setExprColumnInOperand(expr Expression) Expression { case *ScalarFunction: args := v.GetArgs() for i, arg := range args { - args[i] = setExprColumnInOperand(arg) + args[i] = SetExprColumnInOperand(arg) } } return expr @@ -199,30 +200,61 @@ func setExprColumnInOperand(expr Expression) Expression { // ColumnSubstitute substitutes the columns in filter to expressions in select fields. // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression { - _, resExpr := ColumnSubstituteImpl(expr, schema, newExprs) + _, _, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, false) return resExpr } +// ColumnSubstituteAll substitutes the columns just like ColumnSubstitute, but we don't accept partial substitution. +// Only accept: +// +// 1: substitute them all once find col in schema. +// 2: nothing in expr can be substituted. +func ColumnSubstituteAll(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { + _, hasFail, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, true) + return hasFail, resExpr +} + // ColumnSubstituteImpl tries to substitute column expr using newExprs, // the newFunctionInternal is only called if its child is substituted -func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { +// @return bool means whether the expr has changed. +// @return bool means whether the expr should change (has the dependency in schema, while the corresponding expr has some compatibility), but finally fallback. +// @return Expression, the original expr or the changed expr, it depends on the first @return bool. +func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) { switch v := expr.(type) { case *Column: id := schema.ColumnIndex(v) if id == -1 { - return false, v + return false, false, v } newExpr := newExprs[id] if v.InOperand { - newExpr = setExprColumnInOperand(newExpr) + newExpr = SetExprColumnInOperand(newExpr) } newExpr.SetCoercibility(v.Coercibility()) - return true, newExpr + return true, false, newExpr case *ScalarFunction: +<<<<<<< HEAD if v.FuncName.L == ast.Cast { newFunc := v.Clone().(*ScalarFunction) _, newFunc.GetArgs()[0] = ColumnSubstituteImpl(newFunc.GetArgs()[0], schema, newExprs) return true, newFunc +======= + substituted := false + hasFail := false + if v.FuncName.L == ast.Cast { + newFunc := v.Clone().(*ScalarFunction) + substituted, hasFail, newFunc.GetArgs()[0] = ColumnSubstituteImpl(newFunc.GetArgs()[0], schema, newExprs, fail1Return) + if fail1Return && hasFail { + return substituted, hasFail, newFunc + } + if substituted { + // Workaround for issue https://github.com/pingcap/tidb/issues/28804 + e := NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, newFunc.GetArgs()...) + e.SetCoercibility(v.Coercibility()) + return true, false, e + } + return false, false, newFunc +>>>>>>> d3483026e... planner: mark the both side operand of NAAJ & refuse partial column substitute in projection elimination of Apply de-correlation (#37117) } // cowExprRef is a copy-on-write util, args array allocation happens only // when expr in args is changed @@ -230,7 +262,11 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression substituted := false _, coll := DeriveCollationFromExprs(v.GetCtx(), v.GetArgs()...) for idx, arg := range v.GetArgs() { - changed, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs) + changed, hasFail, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs, fail1Return) + if fail1Return && hasFail { + return changed, hasFail, v + } + oldChanged := changed if collate.NewCollationEnabled() { // Make sure the collation used by the ScalarFunction isn't changed and its result collation is not weaker than the collation used by the ScalarFunction. if changed { @@ -243,16 +279,24 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression } } } + if fail1Return && oldChanged != changed { + // Only when the oldChanged is true and changed is false, we will get here. + // And this means there some dependency in this arg can be substituted with + // given expressions, while it has some collation compatibility, finally we + // fall back to use the origin args. (commonly used in projection elimination + // in which fallback usage is unacceptable) + return changed, true, v + } refExprArr.Set(idx, changed, newFuncExpr) if changed { substituted = true } } if substituted { - return true, NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, refExprArr.Result()...) + return true, false, NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, refExprArr.Result()...) } } - return false, expr + return false, false, expr } // checkCollationStrictness check collation strictness-ship between `coll` and `newFuncColl` diff --git a/expression/util_test.go b/expression/util_test.go index e0fc2ad25e7e0..8cc01800dcdac 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -210,13 +210,22 @@ func (s *testUtilSuite) TestGetUint64FromConstant(c *check.C) { func (s *testUtilSuite) TestSetExprColumnInOperand(c *check.C) { col := &Column{RetType: newIntFieldType()} +<<<<<<< HEAD c.Assert(setExprColumnInOperand(col).(*Column).InOperand, check.IsTrue) +======= + require.True(t, SetExprColumnInOperand(col).(*Column).InOperand) +>>>>>>> d3483026e... planner: mark the both side operand of NAAJ & refuse partial column substitute in projection elimination of Apply de-correlation (#37117) f, err := funcs[ast.Abs].getFunction(mock.NewContext(), []Expression{col}) c.Assert(err, check.IsNil) fun := &ScalarFunction{Function: f} +<<<<<<< HEAD setExprColumnInOperand(fun) c.Assert(f.getArgs()[0].(*Column).InOperand, check.IsTrue) +======= + SetExprColumnInOperand(fun) + require.True(t, f.getArgs()[0].(*Column).InOperand) +>>>>>>> d3483026e... planner: mark the both side operand of NAAJ & refuse partial column substitute in projection elimination of Apply de-correlation (#37117) } func (s testUtilSuite) TestPopRowFirstArg(c *check.C) { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 666484f8bef0b..b471daa35d59f 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -461,6 +461,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e rColCopy := *rCol rColCopy.InOperand = true r = &rColCopy + l = expression.SetExprColumnInOperand(l) } } else { rowFunc := r.(*expression.ScalarFunction) @@ -483,6 +484,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e if er.err != nil { return } + l = expression.SetExprColumnInOperand(l) } } } @@ -864,21 +866,40 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.Patte // normal column equal condition, so we specially mark the inner operand here. if v.Not || asScalar { // If both input columns of `in` expression are not null, we can treat the expression - // as normal column equal condition instead. + // as normal column equal condition instead. Otherwise, mark the left and right side. + // eg: for some optimization, the column substitute in right side in projection elimination + // will cause case like as which is not + // a valid null-aware EQ. (null in lcol still need to be null-aware) if !expression.ExprNotNull(lexpr) || !expression.ExprNotNull(rCol) { rColCopy := *rCol rColCopy.InOperand = true rexpr = &rColCopy + lexpr = expression.SetExprColumnInOperand(lexpr) } } } else { args := make([]expression.Expression, 0, np.Schema().Len()) for i, col := range np.Schema().Columns { +<<<<<<< HEAD larg := expression.GetFuncArg(lexpr, i) if !expression.ExprNotNull(larg) || !expression.ExprNotNull(col) { rarg := *col rarg.InOperand = true col = &rarg +======= + if v.Not || asScalar { + larg := expression.GetFuncArg(lexpr, i) + // If both input columns of `in` expression are not null, we can treat the expression + // as normal column equal condition instead. Otherwise, mark the left and right side. + if !expression.ExprNotNull(larg) || !expression.ExprNotNull(col) { + rarg := *col + rarg.InOperand = true + col = &rarg + if larg != nil { + lexpr.(*expression.ScalarFunction).GetArgs()[i] = expression.SetExprColumnInOperand(larg) + } + } +>>>>>>> d3483026e... planner: mark the both side operand of NAAJ & refuse partial column substitute in projection elimination of Apply de-correlation (#37117) } args = append(args, col) } diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 8212375dbb8f9..16686b05a8868 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -184,21 +184,86 @@ func (p *LogicalJoin) GetJoinKeys() (leftKeys, rightKeys []*expression.Column, i return } +<<<<<<< HEAD func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expression.Expression) { - for i, cond := range p.LeftConditions { - p.LeftConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) +======= +// GetPotentialPartitionKeys return potential partition keys for join, the potential partition keys are +// the join keys of EqualConditions +func (p *LogicalJoin) GetPotentialPartitionKeys() (leftKeys, rightKeys []*property.MPPPartitionColumn) { + for _, expr := range p.EqualConditions { + _, coll := expr.CharsetAndCollation() + collateID := property.GetCollateIDByNameForPartition(coll) + leftKeys = append(leftKeys, &property.MPPPartitionColumn{Col: expr.GetArgs()[0].(*expression.Column), CollateID: collateID}) + rightKeys = append(rightKeys, &property.MPPPartitionColumn{Col: expr.GetArgs()[1].(*expression.Column), CollateID: collateID}) } + return +} +// decorrelate eliminate the correlated column with if the col is in schema. +func (p *LogicalJoin) decorrelate(schema *expression.Schema) { +>>>>>>> d3483026e... planner: mark the both side operand of NAAJ & refuse partial column substitute in projection elimination of Apply de-correlation (#37117) + for i, cond := range p.LeftConditions { + p.LeftConditions[i] = cond.Decorrelate(schema) + } for i, cond := range p.RightConditions { - p.RightConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + p.RightConditions[i] = cond.Decorrelate(schema) } - for i, cond := range p.OtherConditions { - p.OtherConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + p.OtherConditions[i] = cond.Decorrelate(schema) + } + for i, cond := range p.EqualConditions { + p.EqualConditions[i] = cond.Decorrelate(schema).(*expression.ScalarFunction) + } +} + +// columnSubstituteAll is used in projection elimination in apply de-correlation. +// Substitutions for all conditions should be successful, otherwise, we should keep all conditions unchanged. +func (p *LogicalJoin) columnSubstituteAll(schema *expression.Schema, exprs []expression.Expression) (hasFail bool) { + // make a copy of exprs for convenience of substitution (may change/partially change the expr tree) + cpLeftConditions := make(expression.CNFExprs, len(p.LeftConditions)) + cpRightConditions := make(expression.CNFExprs, len(p.RightConditions)) + cpOtherConditions := make(expression.CNFExprs, len(p.OtherConditions)) + cpEqualConditions := make([]*expression.ScalarFunction, len(p.EqualConditions)) + copy(cpLeftConditions, p.LeftConditions) + copy(cpRightConditions, p.RightConditions) + copy(cpOtherConditions, p.OtherConditions) + copy(cpEqualConditions, p.EqualConditions) + + // try to substitute columns in these condition. + for i, cond := range cpLeftConditions { + if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + return + } } + for i, cond := range cpRightConditions { + if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + return + } + } + + for i, cond := range cpOtherConditions { + if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + return + } + } + + for i, cond := range cpEqualConditions { + var tmp expression.Expression + if hasFail, tmp = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + return + } + cpEqualConditions[i] = tmp.(*expression.ScalarFunction) + } + + // if all substituted, change them atomically here. + p.LeftConditions = cpLeftConditions + p.RightConditions = cpRightConditions + p.OtherConditions = cpOtherConditions + p.EqualConditions = cpEqualConditions + for i := len(p.EqualConditions) - 1; i >= 0; i-- { - newCond := expression.ColumnSubstitute(p.EqualConditions[i], schema, exprs).(*expression.ScalarFunction) + newCond := p.EqualConditions[i] // If the columns used in the new filter all come from the left child, // we can push this filter to it. @@ -229,6 +294,7 @@ func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expres p.EqualConditions[i] = newCond } + return false } // AttachOnConds extracts on conditions for join and set the `EqualConditions`, `LeftConditions`, `RightConditions` and diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index 1e97d6ff21454..bdeca74162484 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -146,10 +146,48 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan) (Logica return s.optimize(ctx, p) } } else if proj, ok := innerPlan.(*LogicalProjection); ok { +<<<<<<< HEAD +======= + // After the column pruning, some expressions in the projection operator may be pruned. + // In this situation, we can decorrelate the apply operator. + allConst := len(proj.Exprs) > 0 + for _, expr := range proj.Exprs { + if len(expression.ExtractCorColumns(expr)) > 0 || !expression.ExtractColumnSet(expr).IsEmpty() { + allConst = false + break + } + } + if allConst && apply.JoinType == LeftOuterJoin { + // If the projection just references some constant. We cannot directly pull it up when the APPLY is an outer join. + // e.g. select (select 1 from t1 where t1.a=t2.a) from t2; When the t1.a=t2.a is false the join's output is NULL. + // But if we pull the projection upon the APPLY. It will return 1 since the projection is evaluated after the join. + // We disable the decorrelation directly for now. + // TODO: Actually, it can be optimized. We need to first push the projection down to the selection. And then the APPLY can be decorrelated. + goto NoOptimize + } + + // step1: substitute the all the schema with new expressions (including correlated column maybe, but it doesn't affect the collation infer inside) + // eg: projection: constant("guo") --> column8, once upper layer substitution failed here, the lower layer behind + // projection can't supply column8 anymore. + // + // upper OP (depend on column8) --> projection(constant "guo" --> column8) --> lower layer OP + // | ^ + // +-------------------------------------------------------+ + // + // upper OP (depend on column8) --> lower layer OP + // | ^ + // +-----------------------------+ // Fail: lower layer can't supply column8 anymore. + hasFail := apply.columnSubstituteAll(proj.Schema(), proj.Exprs) + if hasFail { + goto NoOptimize + } + // step2: when it can be substituted all, we then just do the de-correlation (apply conditions included). +>>>>>>> d3483026e... planner: mark the both side operand of NAAJ & refuse partial column substitute in projection elimination of Apply de-correlation (#37117) for i, expr := range proj.Exprs { proj.Exprs[i] = expr.Decorrelate(outerPlan.Schema()) } - apply.columnSubstitute(proj.Schema(), proj.Exprs) + apply.decorrelate(outerPlan.Schema()) + innerPlan = proj.children[0] apply.SetChildren(outerPlan, innerPlan) if apply.JoinType != SemiJoin && apply.JoinType != LeftOuterSemiJoin && apply.JoinType != AntiSemiJoin && apply.JoinType != AntiLeftOuterSemiJoin {