[SPARK-38897][SQL] DS V2 supports push down string functions
### What changes were proposed in this pull request?

Currently, Spark has some string functions from the ANSI standard. Please refer to

https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L503

These functions are listed below:
`SUBSTRING`,
`UPPER`,
`LOWER`,
`TRANSLATE`,
`TRIM`,
`OVERLAY`

Mainstream database support for these functions is shown below.
Function | PostgreSQL | ClickHouse | H2 | MySQL | Oracle | Redshift | Presto | Teradata | Snowflake | DB2 | Vertica | Exasol | SqlServer | Yellowbrick | Impala | Mariadb | Druid | Pig | SQLite | Influxdata | Singlestore | ElasticSearch
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
`SUBSTRING` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes
`UPPER` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes
`LOWER` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes
`TRIM` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes
`TRANSLATE` | Yes | No | Yes | No | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | No | No | No | No | No | No
`OVERLAY` | Yes | No | No | No | Yes | No | No | No | No | Yes | Yes | No | No | No | No | No | No | No | No | No | No | No

DS V2 should support pushing down these string functions, as sketched below.
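
For example, with a SparkSession `spark` and a JDBC catalog named `h2` configured (matching the new test in this PR), a string-function filter can be compiled into the data source's own SQL. A minimal sketch:

```scala
// The UPPER filter below can now be evaluated by the underlying database
// rather than by Spark; the plan reports it under PushedFilters.
val df = spark.sql("SELECT * FROM h2.test.employee WHERE upper(name) = 'JEN'")
df.explain()
// expected plan fragment: PushedFilters: [NAME IS NOT NULL, UPPER(NAME) = 'JEN']
```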

### Why are the changes needed?

Pushing these string functions down lets the data source evaluate the filters itself, reducing the amount of data transferred back to Spark.

### Does this PR introduce _any_ user-facing change?

No. This is a new feature.

### How was this patch tested?

New tests.

Closes #36330 from chenzhx/spark-master.

Authored-by: chenzhx <chen@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
chenzhx authored and cloud-fan committed May 23, 2022
1 parent a0decfc commit 724fb08
Showing 6 changed files with 211 additions and 2 deletions.
@@ -148,6 +148,54 @@
* <li>Since version: 3.3.0</li>
* </ul>
* </li>
* <li>Name: <code>SUBSTRING</code>
* <ul>
* <li>SQL semantic: <code>SUBSTRING(str, pos[, len])</code></li>
* <li>Since version: 3.4.0</li>
* </ul>
* </li>
* <li>Name: <code>UPPER</code>
* <ul>
* <li>SQL semantic: <code>UPPER(expr)</code></li>
* <li>Since version: 3.4.0</li>
* </ul>
* </li>
* <li>Name: <code>LOWER</code>
* <ul>
* <li>SQL semantic: <code>LOWER(expr)</code></li>
* <li>Since version: 3.4.0</li>
* </ul>
* </li>
* <li>Name: <code>TRANSLATE</code>
* <ul>
* <li>SQL semantic: <code>TRANSLATE(input, from, to)</code></li>
* <li>Since version: 3.4.0</li>
* </ul>
* </li>
* <li>Name: <code>TRIM</code>
* <ul>
* <li>SQL semantic: <code>TRIM(src, trim)</code></li>
* <li>Since version: 3.4.0</li>
* </ul>
* </li>
* <li>Name: <code>LTRIM</code>
* <ul>
* <li>SQL semantic: <code>LTRIM(src, trim)</code></li>
* <li>Since version: 3.4.0</li>
* </ul>
* </li>
* <li>Name: <code>RTRIM</code>
* <ul>
* <li>SQL semantic: <code>RTRIM(src, trim)</code></li>
* <li>Since version: 3.4.0</li>
* </ul>
* </li>
* <li>Name: <code>OVERLAY</code>
* <ul>
* <li>SQL semantic: <code>OVERLAY(string, replace, position[, length])</code></li>
* <li>Since version: 3.4.0</li>
* </ul>
* </li>
* </ol>
* Note: SQL semantics conform to the ANSI standard, so some expressions are not supported
* when ANSI mode is off, including: add, subtract, multiply, divide, remainder, pmod.
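
For reference, a minimal sketch of the V2 expression this list describes, built by hand for `SUBSTRING(name, 2, 1)`. It is not part of this diff and uses the internal `FieldReference` and `LiteralValue` helpers that the builder code below also imports:

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.types.IntegerType

// The function name must be one of the documented names above; children
// follow the argument order of the documented SQL semantic.
val substring = new GeneralScalarExpression("SUBSTRING",
  Array[Expression](FieldReference("name"),
    LiteralValue(2, IntegerType), LiteralValue(1, IntegerType)))
```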
@@ -102,13 +102,29 @@ public String build(Expression expr) {
case "FLOOR":
case "CEIL":
case "WIDTH_BUCKET":
case "SUBSTRING":
case "UPPER":
case "LOWER":
case "TRANSLATE":
return visitSQLFunction(name,
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
case "CASE_WHEN": {
List<String> children =
Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
return visitCaseWhen(children.toArray(new String[e.children().length]));
}
case "TRIM":
return visitTrim("BOTH",
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
case "LTRIM":
return visitTrim("LEADING",
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
case "RTRIM":
return visitTrim("TRAILING",
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
case "OVERLAY":
return visitOverlay(
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
// TODO: support other expressions
default:
return visitUnexpectedExpr(expr);
@@ -228,4 +244,23 @@ protected String visitSQLFunction(String funcName, String[] inputs) {
protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
}

protected String visitOverlay(String[] inputs) {
assert(inputs.length == 3 || inputs.length == 4);
if (inputs.length == 3) {
return "OVERLAY(" + inputs[0] + " PLACING " + inputs[1] + " FROM " + inputs[2] + ")";
} else {
return "OVERLAY(" + inputs[0] + " PLACING " + inputs[1] + " FROM " + inputs[2] +
" FOR " + inputs[3]+ ")";
}
}

protected String visitTrim(String direction, String[] inputs) {
assert(inputs.length == 1 || inputs.length == 2);
if (inputs.length == 1) {
return "TRIM(" + direction + " FROM " + inputs[0] + ")";
} else {
return "TRIM(" + direction + " " + inputs[1] + " FROM " + inputs[0] + ")";
}
}
}
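
To make the generated syntax concrete, here is a small standalone sketch that mirrors `visitTrim` above (same logic, written as plain Scala rather than the builder itself):

```scala
// One input trims whitespace; two inputs trim the given characters.
// direction is BOTH, LEADING, or TRAILING.
def trimSQL(direction: String, inputs: Seq[String]): String = inputs match {
  case Seq(src) => s"TRIM($direction FROM $src)"
  case Seq(src, trim) => s"TRIM($direction $trim FROM $src)"
  case _ => throw new IllegalArgumentException("expected 1 or 2 inputs")
}

trimSQL("BOTH", Seq("NAME"))           // TRIM(BOTH FROM NAME)
trimSQL("LEADING", Seq("NAME", "'j'")) // TRIM(LEADING 'j' FROM NAME)
```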
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Lower, Multiply, Not, Or, Overlay, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, StringTranslate, StringTrim, StringTrimLeft, StringTrimRight, Substring, Subtract, UnaryMinus, Upper, WidthBucket}
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -200,6 +200,65 @@ class V2ExpressionBuilder(
} else {
None
}
case substring: Substring =>
val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
Seq(substring.str, substring.pos)
} else {
substring.children
}
val childrenExpressions = children.flatMap(generateExpression(_))
if (childrenExpressions.length == children.length) {
Some(new GeneralScalarExpression("SUBSTRING",
childrenExpressions.toArray[V2Expression]))
} else {
None
}
case Upper(child) => generateExpression(child)
.map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v)))
case Lower(child) => generateExpression(child)
.map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v)))
case translate: StringTranslate =>
val childrenExpressions = translate.children.flatMap(generateExpression(_))
if (childrenExpressions.length == translate.children.length) {
Some(new GeneralScalarExpression("TRANSLATE",
childrenExpressions.toArray[V2Expression]))
} else {
None
}
case trim: StringTrim =>
val childrenExpressions = trim.children.flatMap(generateExpression(_))
if (childrenExpressions.length == trim.children.length) {
Some(new GeneralScalarExpression("TRIM", childrenExpressions.toArray[V2Expression]))
} else {
None
}
case trim: StringTrimLeft =>
val childrenExpressions = trim.children.flatMap(generateExpression(_))
if (childrenExpressions.length == trim.children.length) {
Some(new GeneralScalarExpression("LTRIM", childrenExpressions.toArray[V2Expression]))
} else {
None
}
case trim: StringTrimRight =>
val childrenExpressions = trim.children.flatMap(generateExpression(_))
if (childrenExpressions.length == trim.children.length) {
Some(new GeneralScalarExpression("RTRIM", childrenExpressions.toArray[V2Expression]))
} else {
None
}
case overlay: Overlay =>
val children = if (overlay.len == Literal(-1)) {
Seq(overlay.input, overlay.replace, overlay.pos)
} else {
overlay.children
}
val childrenExpressions = children.flatMap(generateExpression(_))
if (childrenExpressions.length == children.length) {
Some(new GeneralScalarExpression("OVERLAY",
childrenExpressions.toArray[V2Expression]))
} else {
None
}
// TODO: support other expressions
case _ => None
}
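
A note on the two special cases above: a two-argument `substr(str, pos)` and a three-argument `overlay(input, replace, pos)` carry implicit default length arguments (`Integer.MAX_VALUE` and `-1` respectively), which the builder strips so the data source only receives the arguments the user wrote. A sketch of the `Substring` case:

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, Substring}
import org.apache.spark.sql.functions.col

// A two-argument substr is represented internally with an implicit
// max-length third argument; the Substring case above drops it, so the
// source sees SUBSTRING(name, 2) rather than SUBSTRING(name, 2, 2147483647).
val twoArg = Substring(col("name").expr, Literal(2), Literal(Integer.MAX_VALUE))
assert(twoArg.len == Literal(Integer.MAX_VALUE))
```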
@@ -31,7 +31,8 @@ private object H2Dialect extends JdbcDialect {
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")

private val supportedFunctions =
Set("ABS", "COALESCE", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL")
Set("ABS", "COALESCE", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL",
"SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM")

override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)
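
Push-down is opt-in per dialect: only function names in this set are compiled into source SQL. A hypothetical dialect (the name and URL prefix are illustrative, not part of this change) would opt in the same way:

```scala
import java.util.Locale
import org.apache.spark.sql.jdbc.JdbcDialect

// Any function name missing from supportedFunctions is rejected by the
// guarded visit* methods, and the corresponding filter stays in Spark.
object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean =
    url.toLowerCase(Locale.ROOT).startsWith("jdbc:mydb")

  private val supportedFunctions = Set("SUBSTRING", "UPPER", "LOWER", "TRIM")

  override def isSupportedFunction(funcName: String): Boolean =
    supportedFunctions.contains(funcName)
}
```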
@@ -251,6 +251,24 @@ abstract class JdbcDialect extends Serializable with Logging {
s"${this.getClass.getSimpleName} does not support function: $funcName")
}
}

override def visitOverlay(inputs: Array[String]): String = {
if (isSupportedFunction("OVERLAY")) {
super.visitOverlay(inputs)
} else {
throw new UnsupportedOperationException(
s"${this.getClass.getSimpleName} does not support function: OVERLAY")
}
}

override def visitTrim(direction: String, inputs: Array[String]): String = {
if (isSupportedFunction("TRIM")) {
super.visitTrim(direction, inputs)
} else {
throw new UnsupportedOperationException(
s"${this.getClass.getSimpleName} does not support function: TRIM")
}
}
}
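
These guards make unsupported functions fail fast, so compilation is abandoned and the filter is evaluated in Spark instead (the OVERLAY test below exercises that path). A dialect whose database uses non-ANSI trim syntax could also override the visitor to emit different SQL; a hypothetical sketch:

```scala
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

// Hypothetical: rewrite the ANSI TRIM forms onto LTRIM/RTRIM/TRIM function
// calls for a database that does not accept TRIM(LEADING ... FROM ...).
class MyDialectSQLBuilder extends V2ExpressionSQLBuilder {
  override def visitTrim(direction: String, inputs: Array[String]): String = {
    val args = inputs.mkString(", ")
    direction match {
      case "LEADING" => s"LTRIM($args)"
      case "TRAILING" => s"RTRIM($args)"
      case _ => s"TRIM($args)"
    }
  }
}
```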

@@ -675,6 +675,54 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
}
}

test("scan with filter push-down with string functions") {
val df1 = sql("select * FROM h2.test.employee where " +
"substr(name, 2, 1) = 'e'" +
" AND upper(name) = 'JEN' AND lower(name) = 'jen' ")
checkFiltersRemoved(df1)
val expectedPlanFragment1 =
"PushedFilters: [NAME IS NOT NULL, (SUBSTRING(NAME, 2, 1)) = 'e', " +
"UPPER(NAME) = 'JEN', LOWER(NAME) = 'jen']"
checkPushedInfo(df1, expectedPlanFragment1)
checkAnswer(df1, Seq(Row(6, "jen", 12000, 1200, true)))

val df2 = sql("select * FROM h2.test.employee where " +
"trim(name) = 'jen' AND trim('j', name) = 'en'" +
"AND translate(name, 'e', 1) = 'j1n'")
checkFiltersRemoved(df2)
val expectedPlanFragment2 =
"PushedFilters: [NAME IS NOT NULL, TRIM(BOTH FROM NAME) = 'jen', " +
"(TRIM(BOTH 'j' FROM NAME)) = 'en', (TRANSLATE(NA..."
checkPushedInfo(df2, expectedPlanFragment2)
checkAnswer(df2, Seq(Row(6, "jen", 12000, 1200, true)))

val df3 = sql("select * FROM h2.test.employee where " +
"ltrim(name) = 'jen' AND ltrim('j', name) = 'en'")
checkFiltersRemoved(df3)
val expectedPlanFragment3 =
"PushedFilters: [TRIM(LEADING FROM NAME) = 'jen', " +
"(TRIM(LEADING 'j' FROM NAME)) = 'en']"
checkPushedInfo(df3, expectedPlanFragment3)
checkAnswer(df3, Seq(Row(6, "jen", 12000, 1200, true)))

val df4 = sql("select * FROM h2.test.employee where " +
"rtrim(name) = 'jen' AND rtrim('n', name) = 'je'")
checkFiltersRemoved(df4)
val expectedPlanFragment4 =
"PushedFilters: [TRIM(TRAILING FROM NAME) = 'jen', " +
"(TRIM(TRAILING 'n' FROM NAME)) = 'je']"
checkPushedInfo(df4, expectedPlanFragment4)
checkAnswer(df4, Seq(Row(6, "jen", 12000, 1200, true)))

// H2 does not support OVERLAY
val df5 = sql("select * FROM h2.test.employee where OVERLAY(NAME, '1', 2, 1) = 'j1n'")
checkFiltersRemoved(df5, false)
val expectedPlanFragment5 =
"PushedFilters: [NAME IS NOT NULL]"
checkPushedInfo(df5, expectedPlanFragment5)
checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true)))
}

test("scan with aggregate push-down: MAX AVG with filter and group by") {
val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" +
" group by DePt")
