[SPARK-38761][SQL] DS V2 supports push down misc non-aggregate functions
### What changes were proposed in this pull request?
Currently, Spark has some misc non-aggregate functions from the ANSI standard. Please refer to https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L362.
These functions are shown below:
`abs`,
`coalesce`,
`nullif`,
`CASE WHEN`
DS V2 should support pushing down these misc non-aggregate functions.
Because DS V2 already supports pushing down `CASE WHEN`, this PR does not need to handle it again.
Because `nullif` extends `RuntimeReplaceable`, it does not need to be handled here either.
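
For illustration, the sketch below (modeled on the new `JDBCV2Suite` test added by this PR; the `h2.test.employee` table, its columns, and the ANSI-mode requirement all come from that suite's fixtures) shows the user-facing effect of the change:

```scala
import org.apache.spark.sql.functions.{abs, coalesce}
import spark.implicits._  // assumes a running SparkSession `spark` with the suite's H2 catalog registered

// Both filters use the newly pushable functions.
val df = spark.table("h2.test.employee")
  .filter(abs($"dept" - 3) > 1)
  .filter(coalesce($"salary", $"bonus") > 2000)

// With ANSI mode enabled, the scan is expected to report something like:
//   PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1,
//     (COALESCE(CAST(SALARY AS double), BONUS)) > 2000.0]
df.explain()
```

Without ANSI mode, only `DEPT IS NOT NULL` is pushed, matching the non-ANSI branch of the new test.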

### Why are the changes needed?
DS V2 should support pushing down misc non-aggregate functions.

### Does this PR introduce _any_ user-facing change?
'No'.
New feature.

### How was this patch tested?
New tests.

Closes #36039 from beliefer/SPARK-38761.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 9ce4ba0)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
beliefer authored and cloud-fan committed Apr 11, 2022
1 parent 5c3ef79 commit b3cd07b
Showing 3 changed files with 44 additions and 25 deletions.
@@ -93,6 +93,10 @@ public String build(Expression expr) {
return visitNot(build(e.children()[0]));
case "~":
return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
case "ABS":
case "COALESCE":
return visitSQLFunction(name,
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
case "CASE_WHEN": {
List<String> children =
Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
@@ -210,6 +214,10 @@ protected String visitCaseWhen(String[] children) {
return sb.toString();
}

protected String visitSQLFunction(String funcName, String[] inputs) {
return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
}

protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
}
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -95,6 +95,15 @@ class V2ExpressionBuilder(
}
case Cast(child, dataType, _, true) =>
generateExpression(child).map(v => new V2Cast(v, dataType))
case Abs(child, true) => generateExpression(child)
.map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v)))
case Coalesce(children) =>
val childrenExpressions = children.flatMap(generateExpression(_))
if (children.length == childrenExpressions.length) {
Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression]))
} else {
None
}
case and: And =>
// AND expects predicate
val l = generateExpression(and.left, true)
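
As a reading aid for the two hunks above, here is a rough, non-authoritative Scala sketch of the translation path they implement (the helper below only mimics the string formatting of the new `visitSQLFunction`; it is not the actual builder code):

```scala
object PushdownSketch {
  // Hypothetical stand-in for visitSQLFunction in the Java SQL builder:
  // renders a function name plus its already-built argument SQL strings.
  def renderSQLFunction(funcName: String, inputs: Array[String]): String =
    funcName + inputs.mkString("(", ", ", ")")

  def main(args: Array[String]): Unit = {
    // A pushed-down coalesce(salary, bonus) flows roughly as:
    //   Catalyst Coalesce(salary, bonus)
    //     -> GeneralScalarExpression("COALESCE", children)  (V2ExpressionBuilder, Scala hunk above)
    //     -> "COALESCE(SALARY, BONUS)"                       (visitSQLFunction, Java hunk above)
    println(renderSQLFunction("COALESCE", Array("SALARY", "BONUS")))  // prints COALESCE(SALARY, BONUS)
  }
}
```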
50 changes: 26 additions & 24 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, not, sum, udf, when}
import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@@ -381,19 +381,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2)))

val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1)

checkFiltersRemoved(df2, ansiMode)

df2.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment = if (ansiMode) {
"PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
} else {
"PushedFilters: [ID IS NOT NULL], "
}
checkKeywordsExistsInExplain(df2, expected_plan_fragment)
val expectedPlanFragment2 = if (ansiMode) {
"PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
} else {
"PushedFilters: [ID IS NOT NULL], "
}

checkPushedInfo(df2, expectedPlanFragment2)
if (ansiMode) {
val e = intercept[SparkException] {
checkAnswer(df2, Seq.empty)
@@ -422,22 +416,30 @@

val df4 = spark.table("h2.test.employee")
.filter(($"salary" > 1000d).and($"salary" < 12000d))

checkFiltersRemoved(df4, ansiMode)

df4.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment = if (ansiMode) {
"PushedFilters: [SALARY IS NOT NULL, " +
"CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], "
} else {
"PushedFilters: [SALARY IS NOT NULL], "
}
checkKeywordsExistsInExplain(df4, expected_plan_fragment)
val expectedPlanFragment4 = if (ansiMode) {
"PushedFilters: [SALARY IS NOT NULL, " +
"CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], "
} else {
"PushedFilters: [SALARY IS NOT NULL], "
}

checkPushedInfo(df4, expectedPlanFragment4)
checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true),
Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))

val df5 = spark.table("h2.test.employee")
.filter(abs($"dept" - 3) > 1)
.filter(coalesce($"salary", $"bonus") > 2000)
checkFiltersRemoved(df5, ansiMode)
val expectedPlanFragment5 = if (ansiMode) {
"PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1, " +
"(COALESCE(CAST(SALARY AS double), BONUS)) > 2000.0]"
} else {
"PushedFilters: [DEPT IS NOT NULL]"
}
checkPushedInfo(df5, expectedPlanFragment5)
checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true)))
}
}
}
