[SPARK-37527][SQL] Translate more standard aggregate functions for pushdown

### What changes were proposed in this pull request?
Currently, Spark aggregate pushdown translates only a few standard aggregate functions, so JDBC dialects cannot compile the others into database-specific SQL.
After this change, users can override `JdbcDialect.compileAggregate` to compile the standard aggregate functions that a given database supports.
This PR translates only the ANSI standard aggregate functions. Mainstream database support for them is shown below:
| Name | ClickHouse | Presto | Teradata | Snowflake | Oracle | Postgresql | Vertica | MySQL | RedShift | ElasticSearch | Impala | Druid | SyBase | DB2 | H2 | Exasol | Mariadb | Phoenix | Yellowbrick | Singlestore | Influxdata | Dolphindb | Intersystems |
|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|
| `VAR_POP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | No | Yes | Yes |
| `VAR_SAMP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No |  Yes | Yes | Yes | No | Yes | Yes | No | Yes | Yes |
| `STDDEV_POP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `STDDEV_SAMP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No |  Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes |
| `COVAR_POP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | Yes | No |  Yes | Yes | No | No | No | No | Yes | Yes | No |
| `COVAR_SAMP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | Yes | No |  Yes | Yes | No | No | No | No | No | No | No |
| `CORR` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | Yes | No |  Yes | Yes | No | No | No | No | No | Yes | No |
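Dialect authors opt into any subset of these by overriding `JdbcDialect.compileAggregate`. The shape of that hook can be sketched without a Spark dependency as follows (`AggFunc`, `Dialect`, and `MyH2LikeDialect` are simplified stand-ins for Spark's `GeneralAggregateFunc` and `JdbcDialect`, not the real API):

```scala
// Simplified model of Spark's dialect hook: the base dialect handles
// nothing, and a concrete dialect adds the functions its database knows.
case class AggFunc(name: String, isDistinct: Boolean, inputs: Seq[String])

trait Dialect {
  // Base dialect: no translation; subclasses extend per database.
  def compileAggregate(f: AggFunc): Option[String] = None
}

object MyH2LikeDialect extends Dialect {
  override def compileAggregate(f: AggFunc): Option[String] =
    super.compileAggregate(f).orElse(f.name match {
      case "VAR_POP" | "VAR_SAMP" | "STDDEV_POP" | "STDDEV_SAMP"
          if f.inputs.length == 1 =>
        val distinct = if (f.isDistinct) "DISTINCT " else ""
        Some(s"${f.name}($distinct${f.inputs.head})")
      case _ => None // unsupported functions stay in Spark
    })
}
```

Returning `None` keeps the aggregate un-pushed, so Spark evaluates it itself; this mirrors the fallback behavior of the real `compileAggregate`.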

Because the Optimizer already rewrites the aggregate functions shown below, this PR does not need to match them.

|Input|Parsed|Optimized|
|------|--------------------|----------|
|`Every`| `aggregate.BoolAnd` |`Min`|
|`Any`| `aggregate.BoolOr` |`Max`|
|`Some`| `aggregate.BoolOr` |`Max`|
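These rewrites are sound because, under the Boolean ordering `false < true`, AND over a group is its minimum and OR is its maximum. A quick self-contained check:

```scala
// EVERY(b) == MIN(b) and ANY/SOME(b) == MAX(b) with false < true.
val xs = Seq(true, false, true)
assert(xs.forall(identity) == xs.min) // EVERY rewrites to MIN
assert(xs.exists(identity) == xs.max) // ANY/SOME rewrite to MAX
```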

### Why are the changes needed?
This lets `*Dialect` implementations extend the set of supported aggregate functions by overriding `JdbcDialect.compileAggregate`.

### Does this PR introduce _any_ user-facing change?
Yes. Users can push down more aggregate functions.

### How was this patch tested?
Existing tests.

Closes apache#35101 from beliefer/SPARK-37527-new2.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Huaxin Gao <huaxin_gao@apple.com>
beliefer authored and chenzhx committed Apr 18, 2022
1 parent 576b1fb commit cc970f1
Showing 4 changed files with 113 additions and 0 deletions.
@@ -32,6 +32,13 @@
* The currently supported SQL aggregate functions:
* <ol>
* <li><pre>AVG(input1)</pre> Since 3.3.0</li>
* <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
* <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
* <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
* <li><pre>STDDEV_SAMP(input1)</pre> Since 3.3.0</li>
* <li><pre>COVAR_POP(input1, input2)</pre> Since 3.3.0</li>
* <li><pre>COVAR_SAMP(input1, input2)</pre> Since 3.3.0</li>
* <li><pre>CORR(input1, input2)</pre> Since 3.3.0</li>
* </ol>
*
* @since 3.3.0
@@ -718,6 +718,27 @@ object DataSourceStrategy
Some(new Sum(FieldReference(name), aggregates.isDistinct))
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("AVG", aggregates.isDistinct, Array(FieldReference(name))))
case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("VAR_POP", aggregates.isDistinct, Array(FieldReference(name))))
case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("VAR_SAMP", aggregates.isDistinct, Array(FieldReference(name))))
case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("STDDEV_POP", aggregates.isDistinct, Array(FieldReference(name))))
case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("STDDEV_SAMP", aggregates.isDistinct, Array(FieldReference(name))))
case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
PushableColumnWithoutNestedColumn(right), _) =>
Some(new GeneralAggregateFunc("COVAR_POP", aggregates.isDistinct,
Array(FieldReference(left), FieldReference(right))))
case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
PushableColumnWithoutNestedColumn(right), _) =>
Some(new GeneralAggregateFunc("COVAR_SAMP", aggregates.isDistinct,
Array(FieldReference(left), FieldReference(right))))
case aggregate.Corr(PushableColumnWithoutNestedColumn(left),
PushableColumnWithoutNestedColumn(right), _) =>
Some(new GeneralAggregateFunc("CORR", aggregates.isDistinct,
Array(FieldReference(left), FieldReference(right))))

case _ => None
}
} else {
sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala (25 additions)
@@ -22,11 +22,36 @@ import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}

private object H2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")

override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
case _ => None
}
)
}
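The four cases added above repeat one assert/DISTINCT/format shape; a hypothetical helper (not part of this PR) shows the shared logic:

```scala
// Hypothetical factoring of the repeated pattern in H2Dialect above:
// build "NAME(col)" or "NAME(DISTINCT col)" for a unary aggregate,
// returning None when the input count is wrong instead of asserting.
object AggSqlSketch {
  def compileUnary(name: String, distinct: Boolean, inputs: Seq[String]): Option[String] = {
    if (inputs.length != 1) None
    else {
      val d = if (distinct) "DISTINCT " else ""
      Some(s"$name($d${inputs.head})")
    }
  }
}
```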

override def classifyException(message: String, e: Throwable): AnalysisException = {
if (e.isInstanceOf[SQLException]) {
// Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
@@ -713,6 +713,66 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1)))
}

test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") {
val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" +
" group by DePt")
checkFiltersRemoved(df)
checkAggregateRemoved(df)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupByColumns: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
}

test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") {
val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" +
" where dept > 0 group by DePt")
checkFiltersRemoved(df)
checkAggregateRemoved(df)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupByColumns: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null)))
}

test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") {
val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" +
" FROM h2.test.employee where dept > 0 group by DePt")
checkFiltersRemoved(df)
checkAggregateRemoved(df, false)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
}

test("scan with aggregate push-down: CORR with filter and group by") {
val df = sql("select CORR(bonus, bonus) FROM h2.test.employee where dept > 0" +
" group by DePt")
checkFiltersRemoved(df)
checkAggregateRemoved(df, false)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
}

test("scan with aggregate push-down: aggregate over alias NOT push down") {
val cols = Seq("a", "b", "c", "d")
val df1 = sql("select * from h2.test.employee").toDF(cols: _*)