[SPARK-37527][SQL] Translate more standard aggregate functions for pushdown #35101

Closed · wants to merge 8 commits
@@ -32,6 +32,13 @@
* The currently supported SQL aggregate functions:
* <ol>
* <li><pre>AVG(input1)</pre> Since 3.3.0</li>
* <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
* <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
* <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
* <li><pre>STDDEV_SAMP(input1)</pre> Since 3.3.0</li>
* <li><pre>COVAR_POP(input1, input2)</pre> Since 3.3.0</li>
* <li><pre>COVAR_SAMP(input1, input2)</pre> Since 3.3.0</li>
* <li><pre>CORR(input1, input2)</pre> Since 3.3.0</li>
* </ol>
*
* @since 3.3.0
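As a side note (not part of the diff): a minimal sketch of how one of the newly documented functions can be constructed with the API this PR uses (GeneralAggregateFunc, FieldReference), against the Spark version this PR targets (3.3.0). The column name "bonus" is only an assumed placeholder; the snippet is spark-shell style.

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

// VAR_POP(bonus) as a DS V2 aggregate expression; "bonus" is an illustrative
// column name, not taken from this PR.
val varPop = new GeneralAggregateFunc(
  "VAR_POP",                      // function name, matching the classdoc entry above
  false,                          // isDistinct
  Array(FieldReference("bonus"))  // single input column
)
```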
@@ -721,6 +721,26 @@ object DataSourceStrategy
          Some(new Sum(FieldReference(name), agg.isDistinct))
        case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
          Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name))))
        case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
          Some(new GeneralAggregateFunc("VAR_POP", agg.isDistinct, Array(FieldReference(name))))
        case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
          Some(new GeneralAggregateFunc("VAR_SAMP", agg.isDistinct, Array(FieldReference(name))))
Contributor:
FYI: I am adding backticks around name in #35108. You might want to do the same.

Contributor Author:
Thanks for the reminder.

Contributor:
Sorry, I shouldn't have asked you to change it yet. I made some changes again in 49348a9.

Contributor Author:
OK. I reverted it.

        case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
          Some(new GeneralAggregateFunc("STDDEV_POP", agg.isDistinct, Array(FieldReference(name))))
        case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
          Some(new GeneralAggregateFunc("STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name))))
        case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
            PushableColumnWithoutNestedColumn(right), _) =>
          Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct,
            Array(FieldReference(left), FieldReference(right))))
        case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
            PushableColumnWithoutNestedColumn(right), _) =>
          Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct,
            Array(FieldReference(left), FieldReference(right))))
        case aggregate.Corr(PushableColumnWithoutNestedColumn(left),
            PushableColumnWithoutNestedColumn(right), _) =>
          Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
            Array(FieldReference(left), FieldReference(right))))
Contributor:
Please update the classdoc of GeneralAggregateFunc to include these new functions.

Contributor Author:
OK.

        case _ => None
      }
    } else {
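A side observation (not part of the diff): unlike functions with dedicated AggregateFunc implementations (e.g. the Sum case above), these new functions are all translated to GeneralAggregateFunc and identified only by name, so each JDBC dialect has to opt in by name in compileAggregate, as the H2Dialect change below does. For the two-argument functions, both columns travel in the inputs array. A minimal sketch, with "bonus" as an assumed column name:

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

// What the CORR branch above builds for corr(bonus, bonus): a two-input
// GeneralAggregateFunc. In the real code path DISTINCT comes from
// agg.isDistinct; it is hard-coded to false here only to keep the sketch
// standalone.
val corrFunc = new GeneralAggregateFunc(
  "CORR",
  false,
  Array(FieldReference("bonus"), FieldReference("bonus"))
)
```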
25 changes: 25 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -22,11 +22,36 @@ import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}

private object H2Dialect extends JdbcDialect {
  override def canHandle(url: String): Boolean =
    url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")

  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
    super.compileAggregate(aggFunction).orElse(
      aggFunction match {
        case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
          assert(f.inputs().length == 1)
          val distinct = if (f.isDistinct) "DISTINCT " else ""
          Some(s"VAR_POP($distinct${f.inputs().head})")
        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
          assert(f.inputs().length == 1)
          val distinct = if (f.isDistinct) "DISTINCT " else ""
          Some(s"VAR_SAMP($distinct${f.inputs().head})")
        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
          assert(f.inputs().length == 1)
          val distinct = if (f.isDistinct) "DISTINCT " else ""
          Some(s"STDDEV_POP($distinct${f.inputs().head})")
        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
          assert(f.inputs().length == 1)
          val distinct = if (f.isDistinct) "DISTINCT " else ""
          Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
        case _ => None
      }
    )
  }

  override def classifyException(message: String, e: Throwable): AnalysisException = {
    if (e.isInstanceOf[SQLException]) {
      // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
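A note on scope (not part of the diff): H2Dialect only compiles the four single-argument functions, so COVAR_POP, COVAR_SAMP and CORR are translated but not pushed into H2, which is why the corresponding tests further down pass false to checkAggregateRemoved. A hypothetical dialect that does support them could follow the same pattern; the dialect name and JDBC URL prefix below are made up.

```scala
import java.util.Locale

import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.jdbc.JdbcDialect

// Hypothetical dialect (not part of this PR) compiling the two-argument
// functions with the same pattern H2Dialect uses for the unary ones.
object ExampleDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean =
    url.toLowerCase(Locale.ROOT).startsWith("jdbc:example")

  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
    super.compileAggregate(aggFunction).orElse(
      aggFunction match {
        case f: GeneralAggregateFunc if f.name() == "COVAR_POP" =>
          assert(f.inputs().length == 2)
          val distinct = if (f.isDistinct) "DISTINCT " else ""
          Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})")
        case f: GeneralAggregateFunc if f.name() == "CORR" =>
          assert(f.inputs().length == 2)
          val distinct = if (f.isDistinct) "DISTINCT " else ""
          Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})")
        case _ => None
      }
    )
  }
}
```

Such a dialect would be registered with JdbcDialects.registerDialect(ExampleDialect) so it takes effect for matching JDBC URLs.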
@@ -698,6 +698,66 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper
    checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1)))
  }

test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") {
val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" +
" group by DePt")
checkFiltersRemoved(df)
checkAggregateRemoved(df)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupByColumns: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
}

test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") {
val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" +
" where dept > 0 group by DePt")
checkFiltersRemoved(df)
checkAggregateRemoved(df)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupByColumns: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null)))
}

test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") {
val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" +
" FROM h2.test.employee where dept > 0 group by DePt")
checkFiltersRemoved(df)
checkAggregateRemoved(df, false)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
}

test("scan with aggregate push-down: CORR with filter and group by") {
val df = sql("select CORR(bonus, bonus) FROM h2.test.employee where dept > 0" +
" group by DePt")
checkFiltersRemoved(df)
checkAggregateRemoved(df, false)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
}

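An aside on the expected rows in the CORR test (not part of the test itself): because the aggregate is not removed here, Spark evaluates CORR, and the Pearson correlation of a column with itself is 1.0 for any group where the value varies, while a single-row group yields NULL. A small standalone sketch, with made-up data:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.corr

// Illustrates the shape of the expected answer: corr(x, x) is 1.0 when x
// varies within the group and NULL for a single-row group.
val spark = SparkSession.builder().master("local[1]").appName("corr-sketch").getOrCreate()
import spark.implicits._

val sample = Seq(
  (1, 100.0), (1, 300.0),  // bonus varies     -> corr = 1.0
  (2, 500.0)               // single-row group -> corr = NULL
).toDF("dept", "bonus")

sample.groupBy("dept").agg(corr($"bonus", $"bonus")).show()
```
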
test("scan with aggregate push-down: aggregate over alias NOT push down") {
val cols = Seq("a", "b", "c", "d")
val df1 = sql("select * from h2.test.employee").toDF(cols: _*)