From e6546bebe1a2aac0a4aeb624a5e64790343d0069 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Apr 2017 14:30:41 +0800 Subject: [PATCH] should not push predicate down through aggregate with non-deterministic expressions --- .../sql/catalyst/optimizer/Optimizer.scala | 8 ++-- .../optimizer/FilterPushdownSuite.scala | 41 +++++++++++++++++-- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 577112779eea4..61edb6e206265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -755,7 +755,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // implies that, for a given input row, the output are determined by the expression's initial // state and all the input rows processed before. In another word, the order of input rows // matters for non-deterministic expressions, while pushing down predicates changes the order. - case filter @ Filter(condition, project @ Project(fields, grandChild)) + // This also applies to Aggregate. + case Filter(condition, project @ Project(fields, grandChild)) if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) => // Create a map of Aliases to their values from the child projection. @@ -792,7 +793,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - case filter @ Filter(condition, aggregate: Aggregate) => + case filter @ Filter(condition, aggregate: Aggregate) + if aggregate.aggregateExpressions.forall(_.deterministic) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { @@ -848,7 +850,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - case filter @ Filter(condition, u: UnaryNode) + case filter @ Filter(_, u: UnaryNode) if canPushThrough(u) && u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index d846786473eb0..ccd0b7c5d7f79 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -134,15 +134,20 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("nondeterministic: can't push down filter with nondeterministic condition through project") { + test("nondeterministic: can always push down filter through project with deterministic field") { val originalQuery = testRelation - .select(Rand(10).as('rand), 'a) - .where('rand > 5 || 'a > 5) + .select('a) + .where(Rand(10) > 5 || 'a > 5) .analyze val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, originalQuery) + val correctAnswer = testRelation + .where(Rand(10) > 5 || 'a > 5) + .select('a) + .analyze + + comparePlans(optimized, correctAnswer) } test("nondeterministic: can't push down filter through project with nondeterministic field") { @@ -156,6 +161,34 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } + test("nondeterministic: can't push down filter through aggregate with nondeterministic field") { + val originalQuery = testRelation + .groupBy('a)('a, Rand(10).as('rand)) + .where('a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + + test("nondeterministic: push down part of filter through aggregate with deterministic field") { + val originalQuery = testRelation + .groupBy('a)('a) + .where('a > 5 && Rand(10) > 5) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .where('a > 5) + .groupBy('a)('a) + .where(Rand(10) > 5) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("filters: combines filters") { val originalQuery = testRelation .select('a)