From c3ec8488d945712d6e2a19987e205f5aee5146de Mon Sep 17 00:00:00 2001
From: angerszhu
Date: Fri, 6 Nov 2020 17:29:30 +0800
Subject: [PATCH] [SPARK-33302][SQL] Failed to push down filters through Expand

---
 .../sql/catalyst/optimizer/Optimizer.scala    |  1 +
 .../optimizer/FilterPushdownSuite.scala       | 32 ++++++++++++++++++-
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9519a56c2817a..51f7799b1e427 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1269,6 +1269,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
     case _: Sort => true
     case _: BatchEvalPython => true
     case _: ArrowEvalPython => true
+    case _: Expand => true
     case _ => false
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 156313300eef9..5e151c7b8ffdb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, IntegerType, TimestampType}
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, TimestampType}
 import org.apache.spark.unsafe.types.CalendarInterval
 
 class FilterPushdownSuite extends PlanTest {
@@ -1208,6 +1208,36 @@ class FilterPushdownSuite extends PlanTest {
       checkAnalysis = false)
   }
 
+
+  test("push down predicate through expand") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val query =
+      Aggregate(
+        Seq('a, 'b),
+        Seq(sum('c).as("sum")),
+        Filter('a > 1,
+          Expand(
+            Seq(
+              Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
+              Seq('a, 'b, 'c, 'a, 2)),
+            Seq('a, 'b, 'c),
+            input))).analyze
+    val optimized = Optimize.execute(query)
+
+    val expected =
+      Aggregate(
+        Seq('a, 'b),
+        Seq(sum('c).as("sum")),
+        Expand(
+          Seq(
+            Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
+            Seq('a, 'b, 'c, 'a, 2)),
+          Seq('a, 'b, 'c),
+          Filter('a > 1, input))).analyze
+
+    comparePlans(optimized, expected)
+  }
+
   test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") {
     val pythonUDFJoinCond = {
       val pythonUDF = PythonUDF("pythonUDF", null,