Skip to content

Commit

Permalink
cherry pick pingcap#37117 to release-6.2
Browse files Browse the repository at this point in the history
Signed-off-by: ti-srebot <ti-srebot@pingcap.com>
  • Loading branch information
AilinKid authored and ti-srebot committed Aug 24, 2022
1 parent daf2b17 commit f4a39be
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 24 deletions.
28 changes: 28 additions & 0 deletions cmd/explaintest/r/subquery.result
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,31 @@ create table t1(a int(11));
create table t2(a decimal(40,20) unsigned, b decimal(40,20));
select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b);
x
drop table if exists stu;
drop table if exists exam;
create table stu(id int, name varchar(100));
insert into stu values(1, null);
create table exam(stu_id int, course varchar(100), grade int);
insert into exam values(1, 'math', 100);
set names utf8 collate utf8_general_ci;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id estRows task access object operator info
Apply 10000.00 root CARTESIAN anti semi join, other cond:eq(test.stu.name, Column#8)
├─TableReader(Build) 10000.00 root data:TableFullScan
│ └─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo
└─Projection(Probe) 10.00 root guo->Column#8
└─TableReader 10.00 root data:Selection
└─Selection 10.00 cop[tikv] eq(test.exam.stu_id, test.stu.id)
└─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id name
set names utf8mb4;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id estRows task access object operator info
HashJoin 8000.00 root anti semi join, equal:[eq(test.stu.id, test.exam.stu_id)], other cond:eq(test.stu.name, "guo")
├─TableReader(Build) 10000.00 root data:TableFullScan
│ └─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo
└─TableReader(Probe) 10000.00 root data:TableFullScan
└─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id name
13 changes: 13 additions & 0 deletions cmd/explaintest/t/subquery.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,16 @@ drop table if exists t1, t2;
create table t1(a int(11));
create table t2(a decimal(40,20) unsigned, b decimal(40,20));
select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b);

drop table if exists stu;
drop table if exists exam;
create table stu(id int, name varchar(100));
insert into stu values(1, null);
create table exam(stu_id int, course varchar(100), grade int);
insert into exam values(1, 'math', 100);
set names utf8 collate utf8_general_ci;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
set names utf8mb4;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
56 changes: 43 additions & 13 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ func extractColumnSet(expr Expression, set *intsets.Sparse) {
}
}

func setExprColumnInOperand(expr Expression) Expression {
// SetExprColumnInOperand is used to set columns in expr as InOperand.
func SetExprColumnInOperand(expr Expression) Expression {
switch v := expr.(type) {
case *Column:
col := v.Clone().(*Column)
Expand All @@ -374,7 +375,7 @@ func setExprColumnInOperand(expr Expression) Expression {
case *ScalarFunction:
args := v.GetArgs()
for i, arg := range args {
args[i] = setExprColumnInOperand(arg)
args[i] = SetExprColumnInOperand(arg)
}
}
return expr
Expand All @@ -383,44 +384,65 @@ func setExprColumnInOperand(expr Expression) Expression {
// ColumnSubstitute substitutes the columns in filter to expressions in select fields.
// e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k.
func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression {
_, resExpr := ColumnSubstituteImpl(expr, schema, newExprs)
_, _, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, false)
return resExpr
}

// ColumnSubstituteAll substitutes the columns just like ColumnSubstitute, but we don't accept partial substitution.
// Only accept:
//
// 1: substitute them all once find col in schema.
// 2: nothing in expr can be substituted.
func ColumnSubstituteAll(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) {
_, hasFail, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, true)
return hasFail, resExpr
}

// ColumnSubstituteImpl tries to substitute column expr using newExprs,
// the newFunctionInternal is only called if its child is substituted
func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) {
// @return bool means whether the expr has changed.
// @return bool means whether the expr should change (has the dependency in schema, while the corresponding expr has some compatibility), but finally fallback.
// @return Expression, the original expr or the changed expr, it depends on the first @return bool.
func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) {
switch v := expr.(type) {
case *Column:
id := schema.ColumnIndex(v)
if id == -1 {
return false, v
return false, false, v
}
newExpr := newExprs[id]
if v.InOperand {
newExpr = setExprColumnInOperand(newExpr)
newExpr = SetExprColumnInOperand(newExpr)
}
newExpr.SetCoercibility(v.Coercibility())
return true, newExpr
return true, false, newExpr
case *ScalarFunction:
substituted := false
hasFail := false
if v.FuncName.L == ast.Cast {
newFunc := v.Clone().(*ScalarFunction)
substituted, newFunc.GetArgs()[0] = ColumnSubstituteImpl(newFunc.GetArgs()[0], schema, newExprs)
substituted, hasFail, newFunc.GetArgs()[0] = ColumnSubstituteImpl(newFunc.GetArgs()[0], schema, newExprs, fail1Return)
if fail1Return && hasFail {
return substituted, hasFail, newFunc
}
if substituted {
// Workaround for issue https://github.com/pingcap/tidb/issues/28804
e := NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, newFunc.GetArgs()...)
e.SetCoercibility(v.Coercibility())
return true, e
return true, false, e
}
return false, newFunc
return false, false, newFunc
}
// cowExprRef is a copy-on-write util, args array allocation happens only
// when expr in args is changed
refExprArr := cowExprRef{v.GetArgs(), nil}
_, coll := DeriveCollationFromExprs(v.GetCtx(), v.GetArgs()...)
for idx, arg := range v.GetArgs() {
changed, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs)
changed, hasFail, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs, fail1Return)
if fail1Return && hasFail {
return changed, hasFail, v
}
oldChanged := changed
if collate.NewCollationEnabled() {
// Make sure the collation used by the ScalarFunction isn't changed and its result collation is not weaker than the collation used by the ScalarFunction.
if changed {
Expand All @@ -433,16 +455,24 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression
}
}
}
if fail1Return && oldChanged != changed {
// Only when the oldChanged is true and changed is false, we will get here.
// And this means there some dependency in this arg can be substituted with
// given expressions, while it has some collation compatibility, finally we
// fall back to use the origin args. (commonly used in projection elimination
// in which fallback usage is unacceptable)
return changed, true, v
}
refExprArr.Set(idx, changed, newFuncExpr)
if changed {
substituted = true
}
}
if substituted {
return true, NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, refExprArr.Result()...)
return true, false, NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, refExprArr.Result()...)
}
}
return false, expr
return false, false, expr
}

// checkCollationStrictness check collation strictness-ship between `coll` and `newFuncColl`
Expand Down
4 changes: 2 additions & 2 deletions expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ func TestGetUint64FromConstant(t *testing.T) {

func TestSetExprColumnInOperand(t *testing.T) {
col := &Column{RetType: newIntFieldType()}
require.True(t, setExprColumnInOperand(col).(*Column).InOperand)
require.True(t, SetExprColumnInOperand(col).(*Column).InOperand)

f, err := funcs[ast.Abs].getFunction(mock.NewContext(), []Expression{col})
require.NoError(t, err)
fun := &ScalarFunction{Function: f}
setExprColumnInOperand(fun)
SetExprColumnInOperand(fun)
require.True(t, f.getArgs()[0].(*Column).InOperand)
}

Expand Down
13 changes: 12 additions & 1 deletion planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e
rColCopy := *rCol
rColCopy.InOperand = true
r = &rColCopy
l = expression.SetExprColumnInOperand(l)
}
} else {
rowFunc := r.(*expression.ScalarFunction)
Expand All @@ -501,6 +502,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e
if er.err != nil {
return
}
l = expression.SetExprColumnInOperand(l)
}
}
}
Expand Down Expand Up @@ -912,22 +914,31 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.Patte
// normal column equal condition, so we specially mark the inner operand here.
if v.Not || asScalar {
// If both input columns of `in` expression are not null, we can treat the expression
// as normal column equal condition instead.
// as normal column equal condition instead. Otherwise, mark the left and right side.
// eg: for some optimization, the column substitute in right side in projection elimination
// will cause case like <lcol EQ rcol(inOperand)> as <lcol EQ constant> which is not
// a valid null-aware EQ. (null in lcol still need to be null-aware)
if !expression.ExprNotNull(lexpr) || !expression.ExprNotNull(rCol) {
rColCopy := *rCol
rColCopy.InOperand = true
rexpr = &rColCopy
lexpr = expression.SetExprColumnInOperand(lexpr)
}
}
} else {
args := make([]expression.Expression, 0, np.Schema().Len())
for i, col := range np.Schema().Columns {
if v.Not || asScalar {
larg := expression.GetFuncArg(lexpr, i)
// If both input columns of `in` expression are not null, we can treat the expression
// as normal column equal condition instead. Otherwise, mark the left and right side.
if !expression.ExprNotNull(larg) || !expression.ExprNotNull(col) {
rarg := *col
rarg.InOperand = true
col = &rarg
if larg != nil {
lexpr.(*expression.ScalarFunction).GetArgs()[i] = expression.SetExprColumnInOperand(larg)
}
}
}
args = append(args, col)
Expand Down
64 changes: 57 additions & 7 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,21 +375,70 @@ func (p *LogicalJoin) GetPotentialPartitionKeys() (leftKeys, rightKeys []*proper
return
}

func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expression.Expression) {
// decorrelate eliminate the correlated column with if the col is in schema.
func (p *LogicalJoin) decorrelate(schema *expression.Schema) {
for i, cond := range p.LeftConditions {
p.LeftConditions[i] = expression.ColumnSubstitute(cond, schema, exprs)
p.LeftConditions[i] = cond.Decorrelate(schema)
}

for i, cond := range p.RightConditions {
p.RightConditions[i] = expression.ColumnSubstitute(cond, schema, exprs)
p.RightConditions[i] = cond.Decorrelate(schema)
}

for i, cond := range p.OtherConditions {
p.OtherConditions[i] = expression.ColumnSubstitute(cond, schema, exprs)
p.OtherConditions[i] = cond.Decorrelate(schema)
}
for i, cond := range p.EqualConditions {
p.EqualConditions[i] = cond.Decorrelate(schema).(*expression.ScalarFunction)
}
}

// columnSubstituteAll is used in projection elimination in apply de-correlation.
// Substitutions for all conditions should be successful, otherwise, we should keep all conditions unchanged.
func (p *LogicalJoin) columnSubstituteAll(schema *expression.Schema, exprs []expression.Expression) (hasFail bool) {
// make a copy of exprs for convenience of substitution (may change/partially change the expr tree)
cpLeftConditions := make(expression.CNFExprs, len(p.LeftConditions))
cpRightConditions := make(expression.CNFExprs, len(p.RightConditions))
cpOtherConditions := make(expression.CNFExprs, len(p.OtherConditions))
cpEqualConditions := make([]*expression.ScalarFunction, len(p.EqualConditions))
copy(cpLeftConditions, p.LeftConditions)
copy(cpRightConditions, p.RightConditions)
copy(cpOtherConditions, p.OtherConditions)
copy(cpEqualConditions, p.EqualConditions)

// try to substitute columns in these condition.
for i, cond := range cpLeftConditions {
if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
}

for i, cond := range cpRightConditions {
if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
}

for i, cond := range cpOtherConditions {
if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
}

for i, cond := range cpEqualConditions {
var tmp expression.Expression
if hasFail, tmp = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
cpEqualConditions[i] = tmp.(*expression.ScalarFunction)
}

// if all substituted, change them atomically here.
p.LeftConditions = cpLeftConditions
p.RightConditions = cpRightConditions
p.OtherConditions = cpOtherConditions
p.EqualConditions = cpEqualConditions

for i := len(p.EqualConditions) - 1; i >= 0; i-- {
newCond := expression.ColumnSubstitute(p.EqualConditions[i], schema, exprs).(*expression.ScalarFunction)
newCond := p.EqualConditions[i]

// If the columns used in the new filter all come from the left child,
// we can push this filter to it.
Expand Down Expand Up @@ -420,6 +469,7 @@ func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expres

p.EqualConditions[i] = newCond
}
return false
}

// AttachOnConds extracts on conditions for join and set the `EqualConditions`, `LeftConditions`, `RightConditions` and
Expand Down
20 changes: 19 additions & 1 deletion planner/core/rule_decorrelate.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,28 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo
// TODO: Actually, it can be optimized. We need to first push the projection down to the selection. And then the APPLY can be decorrelated.
goto NoOptimize
}

// step1: substitute the all the schema with new expressions (including correlated column maybe, but it doesn't affect the collation infer inside)
// eg: projection: constant("guo") --> column8, once upper layer substitution failed here, the lower layer behind
// projection can't supply column8 anymore.
//
// upper OP (depend on column8) --> projection(constant "guo" --> column8) --> lower layer OP
// | ^
// +-------------------------------------------------------+
//
// upper OP (depend on column8) --> lower layer OP
// | ^
// +-----------------------------+ // Fail: lower layer can't supply column8 anymore.
hasFail := apply.columnSubstituteAll(proj.Schema(), proj.Exprs)
if hasFail {
goto NoOptimize
}
// step2: when it can be substituted all, we then just do the de-correlation (apply conditions included).
for i, expr := range proj.Exprs {
proj.Exprs[i] = expr.Decorrelate(outerPlan.Schema())
}
apply.columnSubstitute(proj.Schema(), proj.Exprs)
apply.decorrelate(outerPlan.Schema())

innerPlan = proj.children[0]
apply.SetChildren(outerPlan, innerPlan)
if apply.JoinType != SemiJoin && apply.JoinType != LeftOuterSemiJoin && apply.JoinType != AntiSemiJoin && apply.JoinType != AntiLeftOuterSemiJoin {
Expand Down

0 comments on commit f4a39be

Please sign in to comment.