From adc3e5b1f0a17d82f46d1e5d093081662bad905b Mon Sep 17 00:00:00 2001 From: Zhang Jian Date: Mon, 11 Feb 2019 15:01:36 +0800 Subject: [PATCH] planner: fix assertion failure on LogicalJoin.EqualConditions (#9066) (#9265) --- executor/join_test.go | 15 ++++++++++++ expression/schema.go | 16 ++++++++++-- expression/util_test.go | 18 ++++++++++++++ plan/logical_plans.go | 54 +++++++++++++++++++++++++++++------------ 4 files changed, 86 insertions(+), 17 deletions(-) diff --git a/executor/join_test.go b/executor/join_test.go index eac2579af845a..abaa333b6546d 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -959,3 +959,18 @@ func (s *testSuite) TestJoinDifferentDecimals(c *C) { c.Assert(len(row), Equals, 3) rst.Check(testkit.Rows("1 1.000", "2 2.000", "3 3.000")) } + +func (s *testSuite) TestSubquery2Join(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("Use test") + tk.MustExec(`drop table if exists t1;`) + tk.MustExec(`drop table if exists t2;`) + tk.MustExec(`create table t1(a bigint, b bigint);`) + tk.MustExec(`create table t2(a bigint, b bigint);`) + tk.MustExec(`insert into t1 values(1, 1);`) + tk.MustExec(`insert into t1 values(2, 1);`) + tk.MustExec(`insert into t2 values(1, 1);`) + tk.MustQuery(`select * from t1 where t1.a in (select t1.b + t2.b from t2);`).Check(testkit.Rows( + `2 1`, + )) +} diff --git a/expression/schema.go b/expression/schema.go index 18b0fa7169eca..7113a599d7cbb 100644 --- a/expression/schema.go +++ b/expression/schema.go @@ -90,8 +90,20 @@ func (s *Schema) Clone() *Schema { // ExprFromSchema checks if all columns of this expression are from the same schema. 
func ExprFromSchema(expr Expression, schema *Schema) bool { - cols := ExtractColumns(expr) - return len(schema.ColumnsIndices(cols)) > 0 + switch v := expr.(type) { + case *Column: + return schema.Contains(v) + case *ScalarFunction: + for _, arg := range v.GetArgs() { + if !ExprFromSchema(arg, schema) { + return false + } + } + return true + case *CorrelatedColumn, *Constant: + return true + } + return false } // FindColumn finds an Column from schema for a ast.ColumnName. It compares the db/table/column names. diff --git a/expression/util_test.go b/expression/util_test.go index 2fdbc067dff9e..d7339b3c72bc9 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -106,3 +106,21 @@ func BenchmarkExtractColumns(b *testing.B) { } b.ReportAllocs() } + +func BenchmarkExprFromSchema(b *testing.B) { + conditions := []Expression{ + newFunction(ast.EQ, newColumn(0), newColumn(1)), + newFunction(ast.EQ, newColumn(1), newColumn(2)), + newFunction(ast.EQ, newColumn(2), newColumn(3)), + newFunction(ast.EQ, newColumn(3), newLonglong(1)), + newFunction(ast.LogicOr, newLonglong(1), newColumn(0)), + } + expr := ComposeCNFCondition(mock.NewContext(), conditions...) 
+ schema := &Schema{Columns: ExtractColumns(expr)} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ExprFromSchema(expr, schema) + } + b.ReportAllocs() +} diff --git a/plan/logical_plans.go b/plan/logical_plans.go index 853dba694a75f..2e8d99b3a7c79 100644 --- a/plan/logical_plans.go +++ b/plan/logical_plans.go @@ -123,25 +123,49 @@ type LogicalJoin struct { } func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expression.Expression) { + for i, cond := range p.LeftConditions { + p.LeftConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + } + + for i, cond := range p.RightConditions { + p.RightConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + } + + for i, cond := range p.OtherConditions { + p.OtherConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + } + for i := len(p.EqualConditions) - 1; i >= 0; i-- { - p.EqualConditions[i] = expression.ColumnSubstitute(p.EqualConditions[i], schema, exprs).(*expression.ScalarFunction) - // After the column substitute, the equal condition may become single side condition. - if p.children[0].Schema().Contains(p.EqualConditions[i].GetArgs()[1].(*expression.Column)) { - p.LeftConditions = append(p.LeftConditions, p.EqualConditions[i]) + newCond := expression.ColumnSubstitute(p.EqualConditions[i], schema, exprs).(*expression.ScalarFunction) + + // If every column of the substituted condition comes from the left child, + // it is now a single-side filter: move it to LeftConditions. + if expression.ExprFromSchema(newCond, p.children[0].Schema()) { + p.LeftConditions = append(p.LeftConditions, newCond) p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) - } else if p.children[1].Schema().Contains(p.EqualConditions[i].GetArgs()[0].(*expression.Column)) { - p.RightConditions = append(p.RightConditions, p.EqualConditions[i]) + continue + } + + // If every column of the substituted condition comes from the right + // child, it is now a single-side filter: move it to RightConditions.
+ if expression.ExprFromSchema(newCond, p.children[1].Schema()) { + p.RightConditions = append(p.RightConditions, newCond) p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) + continue } - } - for i, fun := range p.LeftConditions { - p.LeftConditions[i] = expression.ColumnSubstitute(fun, schema, exprs) - } - for i, fun := range p.RightConditions { - p.RightConditions[i] = expression.ColumnSubstitute(fun, schema, exprs) - } - for i, fun := range p.OtherConditions { - p.OtherConditions[i] = expression.ColumnSubstitute(fun, schema, exprs) + + _, lhsIsCol := newCond.GetArgs()[0].(*expression.Column) + _, rhsIsCol := newCond.GetArgs()[1].(*expression.Column) + + // If either argument of the substituted condition is not an expression.Column, + // it cannot be used as a join equal condition: move it to OtherConditions. + if !(lhsIsCol && rhsIsCol) { + p.OtherConditions = append(p.OtherConditions, newCond) + p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) + continue + } + + p.EqualConditions[i] = newCond } }