From a5f52a8c7449bc65c35deb06ddb2b9f4bd059104 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 5 Jun 2020 14:40:20 +0800 Subject: [PATCH] Fix --- .../sql/catalyst/optimizer/Optimizer.scala | 5 +++-- .../plans/logical/QueryPlanConstraints.scala | 19 ++++++++++++------- .../InferFiltersFromConstraintsSuite.scala | 9 ++++----- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f1a307b1c2cc1..751055b9bf253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -886,9 +886,10 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] left: LogicalPlan, right: LogicalPlan, conditionOpt: Option[Expression]): Set[Expression] = { - val baseConstraints = left.constraints.union(right.constraints) - .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet) + val condition = conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet + val baseConstraints = left.constraints.union(right.constraints).union(condition) baseConstraints.union(inferAdditionalConstraints(baseConstraints)) + .union(inferIsNotNullConstraintsForJoinCondition(condition)) } private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 2716f0bf68b6d..bf2f1940af798 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -104,6 +104,18 @@ trait ConstraintHelper { isNotNullConstraints -- constraints } + /** + * Infers a set of `isNotNull` constraints for non null intolerant child from null intolerant + * expressions. For e.g., if an expression is of the form (`coalesce(t1.a, t1.b) = t2.a`), + * this returns a constraint of the form `isNotNull(coalesce(t1.a, t1.b))` + */ + def inferIsNotNullConstraintsForJoinCondition(constraints: Set[Expression]): Set[Expression] = { + constraints.filter(_.isInstanceOf[NullIntolerant]) + .flatMap { e => + e.children.filter(_.references.nonEmpty).filter(c => inferIsNotNullConstraints(c).isEmpty) + }.map(IsNotNull) + } + /** * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions * of constraints. @@ -113,13 +125,6 @@ trait ConstraintHelper { // When the root is IsNotNull, we can push IsNotNull through the child null intolerant // expressions case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) - // For join condition: CAST(coalesce(t1.a, t1.b) as DECIMAL) = CAST(t2.c AS DECIMAL). - // We can infer an additional constraint: CAST(coalesce(t1.a, t1.b) as DECIMAL) IS NOT NULL - // to avoid data skew. - case e: BinaryComparison if e.isInstanceOf[NullIntolerant] => - e.children.filter(_.references.nonEmpty).flatMap { c => - Option(scanNullIntolerantAttribute(c)).filter(_.nonEmpty).getOrElse(Seq(c)) - }.map(IsNotNull(_)) // Constraints always return true for all the inputs. That means, null will never be returned. // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child // null intolerant expressions. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index bc2c66454d2ad..2f9daf2855a69 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -317,7 +317,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } } - test("Infer IsNotNull for all children of binary comparison children") { + test("Infer IsNotNull for non null-intolerant child of null intolerant join condition") { testConstraintsAfterJoin( testRelation.subquery('left), testRelation.subquery('right), @@ -327,9 +327,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Some(Coalesce(Seq("left.a".attr, "left.b".attr)) === "right.c".attr)) } - test("Should not infer IsNotNull for non-binary comparison children") { - val query = testRelation.where(Not('b.in(ListQuery(testRelation.select('a))))).analyze - val optimized = Optimize.execute(query) - comparePlans(optimized, query) + test("Should not infer IsNotNull for non null-intolerant child from same table") { + comparePlans(Optimize.execute(testRelation.where(Coalesce(Seq('a, 'b)) === 'c).analyze), + testRelation.where(Coalesce(Seq('a, 'b)) === 'c && IsNotNull('c)).analyze) } }