diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9519a56c2817a..51f7799b1e427 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1269,6 +1269,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
     case _: Sort => true
     case _: BatchEvalPython => true
     case _: ArrowEvalPython => true
+    case _: Expand => true
     case _ => false
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 156313300eef9..11ec037c94f73 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, IntegerType, TimestampType}
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, TimestampType}
 import org.apache.spark.unsafe.types.CalendarInterval
 
 class FilterPushdownSuite extends PlanTest {
@@ -1208,6 +1208,28 @@ class FilterPushdownSuite extends PlanTest {
       checkAnalysis = false)
   }
 
+  test("push down predicate through expand") {
+    val query =
+      Filter('a > 1,
+        Expand(
+          Seq(
+            Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
+            Seq('a, 'b, 'c, 'a, 2)),
+          Seq('a, 'b, 'c),
+          testRelation)).analyze
+    val optimized = Optimize.execute(query)
+
+    val expected =
+      Expand(
+        Seq(
+          Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
+          Seq('a, 'b, 'c, 'a, 2)),
+        Seq('a, 'b, 'c),
+        Filter('a > 1, testRelation)).analyze
+
+    comparePlans(optimized, expected)
+  }
+
   test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") {
     val pythonUDFJoinCond = {
       val pythonUDF = PythonUDF("pythonUDF", null,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index a3da9f73ebd40..729a1e9f06ca5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -315,6 +315,21 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
+  test("Unary: LeftSemi join push down through Expand") {
+    val expand = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)),
+      Seq('a, 'b, 'c), testRelation)
+    val originalQuery = expand
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'b === 1))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)),
+      Seq('a, 'b, 'c), testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'b === 1)))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   Seq(Some('d === 'e), None).foreach { case innerJoinCond =>
     Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
       Seq(Inner, LeftOuter, Cross, RightOuter).foreach { case innerJT =>