From df1937153b1615cc89d572dd5429c2f63d9b470e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 10 May 2022 17:37:23 +0800 Subject: [PATCH] [SPARK-39135][SQL] DS V2 aggregate partial push-down should support group by without aggregate functions ### What changes were proposed in this pull request? Currently, the SQL shown below is not supported by DS V2 aggregate partial push-down. `select key from tab group by key` ### Why are the changes needed? Make DS V2 aggregate partial push-down support group by without aggregate functions. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New tests Closes #36492 from beliefer/SPARK-39135. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../v2/V2ScanRelationPushDown.scala | 2 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 51 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 64bb6b2834ea1..6291b3c8f953e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -286,7 +286,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down. 
- agg.aggregateExpressions().exists { + agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().exists { case sum: Sum => !sum.isDistinct case count: Count => !count.isDistinct case avg: Avg => !avg.isDistinct diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 5f730745b775f..b111a71903ffd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -670,6 +670,57 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row(5))) } + test("scan with aggregate push-down: GROUP BY without aggregate functions") { + val df = sql("select name FROM h2.test.employee GROUP BY name") + checkAggregateRemoved(df) + checkPushedInfo(df, + "PushedAggregates: [], PushedFilters: [], PushedGroupByExpressions: [NAME],") + checkAnswer(df, Seq(Row("alex"), Row("amy"), Row("cathy"), Row("david"), Row("jen"))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"name") + .agg(Map.empty[String, String]) + checkAggregateRemoved(df2, false) + checkPushedInfo(df2, + "PushedAggregates: [], PushedFilters: [], PushedGroupByExpressions: [NAME],") + checkAnswer(df2, Seq(Row("alex"), Row("amy"), Row("cathy"), Row("david"), Row("jen"))) + + val df3 = sql("SELECT CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END as" + + " key FROM h2.test.employee GROUP BY key") + checkAggregateRemoved(df3) + checkPushedInfo(df3, + """ + |PushedAggregates: [], + |PushedFilters: [], + |PushedGroupByExpressions: + |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df3, Seq(Row(0), Row(9000))) + + val df4 = spark.read + 
.option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy(when(($"SALARY" > 8000).and($"SALARY" < 10000), $"SALARY").otherwise(0).as("key")) + .agg(Map.empty[String, String]) + checkAggregateRemoved(df4, false) + checkPushedInfo(df4, + """ + |PushedAggregates: [], + |PushedFilters: [], + |PushedGroupByExpressions: + |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df4, Seq(Row(0), Row(9000))) + } + test("scan with aggregate push-down: COUNT(col)") { val df = sql("select COUNT(DEPT) FROM h2.test.employee") checkAggregateRemoved(df)