diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 15071c960a2ba..4edc3521d841d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -148,6 +148,54 @@
*
 * - Since version: 3.3.0
*
*
+ * Name: SUBSTRING
+ *
+ * - SQL semantic: SUBSTRING(str, pos[, len])
+ * - Since version: 3.4.0
+ *
+ *
+ * Name: UPPER
+ *
+ * - SQL semantic: UPPER(expr)
+ * - Since version: 3.4.0
+ *
+ *
+ * Name: LOWER
+ *
+ * - SQL semantic: LOWER(expr)
+ * - Since version: 3.4.0
+ *
+ *
+ * Name: TRANSLATE
+ *
+ * - SQL semantic: TRANSLATE(input, from, to)
+ * - Since version: 3.4.0
+ *
+ *
+ * Name: TRIM
+ *
+ * - SQL semantic: TRIM(src, trim)
+ * - Since version: 3.4.0
+ *
+ *
+ * Name: LTRIM
+ *
+ * - SQL semantic: LTRIM(src, trim)
+ * - Since version: 3.4.0
+ *
+ *
+ * Name: RTRIM
+ *
+ * - SQL semantic: RTRIM(src, trim)
+ * - Since version: 3.4.0
+ *
+ *
+ * Name: OVERLAY
+ *
+ * - SQL semantic: OVERLAY(string, replace, position[, length])
+ * - Since version: 3.4.0
+ *
+ *
*
 * Note: SQL semantics conform to the ANSI standard, so some expressions are not supported
 * when ANSI mode is off, including: add, subtract, multiply, divide, remainder, pmod.
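
Each documented name above is the contract between Catalyst and a connector: the function travels as a GeneralScalarExpression whose name matches the entry. A minimal sketch of what a data source receives for SUBSTRING(name, 2, 1), using the public connector expression API (the column name is illustrative):

    import org.apache.spark.sql.connector.expressions.{
      Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
    import org.apache.spark.sql.types.IntegerType

    // SUBSTRING(name, 2, 1) as the V2 expression handed to a data source.
    val substring = new GeneralScalarExpression("SUBSTRING", Array[V2Expression](
      FieldReference("name"),
      LiteralValue(2, IntegerType),
      LiteralValue(1, IntegerType)))
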
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index c9dfa2003e3c1..396b1d9cdd034 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -102,6 +102,10 @@ public String build(Expression expr) {
case "FLOOR":
case "CEIL":
case "WIDTH_BUCKET":
+ case "SUBSTRING":
+ case "UPPER":
+ case "LOWER":
+ case "TRANSLATE":
return visitSQLFunction(name,
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
case "CASE_WHEN": {
@@ -109,6 +113,18 @@ public String build(Expression expr) {
Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
return visitCaseWhen(children.toArray(new String[e.children().length]));
}
+ case "TRIM":
+ return visitTrim("BOTH",
+ Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+ case "LTRIM":
+ return visitTrim("LEADING",
+ Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+ case "RTRIM":
+ return visitTrim("TRAILING",
+ Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+ case "OVERLAY":
+ return visitOverlay(
+ Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
// TODO supports other expressions
default:
return visitUnexpectedExpr(expr);
@@ -228,4 +244,23 @@ protected String visitSQLFunction(String funcName, String[] inputs) {
protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
}
+
+ protected String visitOverlay(String[] inputs) {
+ assert(inputs.length == 3 || inputs.length == 4);
+ if (inputs.length == 3) {
+ return "OVERLAY(" + inputs[0] + " PLACING " + inputs[1] + " FROM " + inputs[2] + ")";
+ } else {
+ return "OVERLAY(" + inputs[0] + " PLACING " + inputs[1] + " FROM " + inputs[2] +
+ " FOR " + inputs[3]+ ")";
+ }
+ }
+
+ protected String visitTrim(String direction, String[] inputs) {
+ assert(inputs.length == 1 || inputs.length == 2);
+ if (inputs.length == 1) {
+ return "TRIM(" + direction + " FROM " + inputs[0] + ")";
+ } else {
+ return "TRIM(" + direction + " " + inputs[1] + " FROM " + inputs[0] + ")";
+ }
+ }
}
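
To see the exact SQL these visitors emit, here is a sketch that exposes the protected methods through a hypothetical subclass (DemoSQLBuilder is not part of the API; the expected strings follow directly from the concatenation above):

    import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

    // Hypothetical helper solely to call the protected visitors.
    class DemoSQLBuilder extends V2ExpressionSQLBuilder {
      def trimSQL(direction: String, inputs: Array[String]): String =
        visitTrim(direction, inputs)
      def overlaySQL(inputs: Array[String]): String = visitOverlay(inputs)
    }

    val b = new DemoSQLBuilder
    b.trimSQL("BOTH", Array("NAME"))             // TRIM(BOTH FROM NAME)
    b.trimSQL("LEADING", Array("NAME", "'j'"))   // TRIM(LEADING 'j' FROM NAME)
    b.overlaySQL(Array("NAME", "'1'", "2"))      // OVERLAY(NAME PLACING '1' FROM 2)
    b.overlaySQL(Array("NAME", "'1'", "2", "1")) // OVERLAY(NAME PLACING '1' FROM 2 FOR 1)

Note that the single-argument TRIM still renders a direction keyword, which keeps the generated SQL unambiguous across dialects.
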
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 487b809d48a01..a2f35ec1152d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.util
-import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Lower, Multiply, Not, Or, Overlay, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, StringTranslate, StringTrim, StringTrimLeft, StringTrimRight, Substring, Subtract, UnaryMinus, Upper, WidthBucket}
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -200,6 +200,65 @@ class V2ExpressionBuilder(
} else {
None
}
+ case substring: Substring =>
+ val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
+ Seq(substring.str, substring.pos)
+ } else {
+ substring.children
+ }
+ val childrenExpressions = children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == children.length) {
+ Some(new GeneralScalarExpression("SUBSTRING",
+ childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
+ case Upper(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v)))
+ case Lower(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v)))
+ case translate: StringTranslate =>
+ val childrenExpressions = translate.children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == translate.children.length) {
+ Some(new GeneralScalarExpression("TRANSLATE",
+ childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
+ case trim: StringTrim =>
+ val childrenExpressions = trim.children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == trim.children.length) {
+ Some(new GeneralScalarExpression("TRIM", childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
+ case trim: StringTrimLeft =>
+ val childrenExpressions = trim.children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == trim.children.length) {
+ Some(new GeneralScalarExpression("LTRIM", childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
+ case trim: StringTrimRight =>
+ val childrenExpressions = trim.children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == trim.children.length) {
+ Some(new GeneralScalarExpression("RTRIM", childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
+ case overlay: Overlay =>
+ val children = if (overlay.len == Literal(-1)) {
+ Seq(overlay.input, overlay.replace, overlay.pos)
+ } else {
+ overlay.children
+ }
+ val childrenExpressions = children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == children.length) {
+ Some(new GeneralScalarExpression("OVERLAY",
+ childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
// TODO supports other expressions
case _ => None
}
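
The Substring and Overlay cases drop parser sentinels rather than pushing them down: two-argument substr parses with len = Integer.MAX_VALUE and overlay without FOR uses len = -1, and neither literal belongs in the generated SQL. A sketch of the sentinels, assuming Catalyst's auxiliary constructors (which the equality checks above rely on):

    import org.apache.spark.sql.catalyst.expressions.{Literal, Overlay, Substring}

    // substr(name, 2): the two-arg form carries a sentinel length.
    val twoArgSubstring = new Substring(Literal("jen"), Literal(2))
    assert(twoArgSubstring.len == Literal(Integer.MAX_VALUE))

    // overlay(...) without FOR carries -1 the same way.
    val threeArgOverlay = new Overlay(Literal("jen"), Literal("1"), Literal(2))
    assert(threeArgOverlay.len == Literal(-1))
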
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 4a88203ec59c9..7edcf1d519295 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -31,7 +31,8 @@ private object H2Dialect extends JdbcDialect {
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
private val supportedFunctions =
- Set("ABS", "COALESCE", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL")
+ Set("ABS", "COALESCE", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL",
+ "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM")
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)
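
A dialect opts in per function name; anything missing from the set stays on the Spark side. A hedged sketch of a third-party dialect that also accepts OVERLAY (DemoDialect and its JDBC URL are made up):

    import java.util.Locale
    import org.apache.spark.sql.jdbc.JdbcDialect

    // Hypothetical dialect: listing OVERLAY lets visitOverlay run instead of
    // throwing, so the predicate compiles and is pushed to the remote database.
    private object DemoDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean =
        url.toLowerCase(Locale.ROOT).startsWith("jdbc:demo")

      private val supportedFunctions =
        Set("SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM", "OVERLAY")

      override def isSupportedFunction(funcName: String): Boolean =
        supportedFunctions.contains(funcName)
    }
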
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index e1883e4e7f4b8..9d37d8ef334b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -251,6 +251,24 @@ abstract class JdbcDialect extends Serializable with Logging {
s"${this.getClass.getSimpleName} does not support function: $funcName")
}
}
+
+ override def visitOverlay(inputs: Array[String]): String = {
+ if (isSupportedFunction("OVERLAY")) {
+ super.visitOverlay(inputs)
+ } else {
+ throw new UnsupportedOperationException(
+ s"${this.getClass.getSimpleName} does not support function: OVERLAY")
+ }
+ }
+
+ override def visitTrim(direction: String, inputs: Array[String]): String = {
+ if (isSupportedFunction("TRIM")) {
+ super.visitTrim(direction, inputs)
+ } else {
+ throw new UnsupportedOperationException(
+ s"${this.getClass.getSimpleName} does not support function: TRIM")
+ }
+ }
}
/**
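
These overrides surface through JdbcDialect.compileExpression, which catches the UnsupportedOperationException and reports the expression as uncompilable, so an unsupported TRIM or OVERLAY is evaluated by Spark after the scan rather than failing the query. Roughly, assuming the existing compileExpression API:

    import org.apache.spark.sql.connector.expressions.{Expression => V2Expression}
    import org.apache.spark.sql.jdbc.JdbcDialects

    // None means "not pushed", never a failed query: the dialect's builder throws
    // for unsupported functions and compileExpression swallows the exception.
    def tryCompile(url: String, expr: V2Expression): Option[String] =
      JdbcDialects.get(url).compileExpression(expr)
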
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index fd186b764fb4c..b9aa0486648d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -675,6 +675,54 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
}
}
+ test("scan with filter push-down with string functions") {
+ val df1 = sql("select * FROM h2.test.employee where " +
+ "substr(name, 2, 1) = 'e'" +
+ " AND upper(name) = 'JEN' AND lower(name) = 'jen' ")
+ checkFiltersRemoved(df1)
+ val expectedPlanFragment1 =
+ "PushedFilters: [NAME IS NOT NULL, (SUBSTRING(NAME, 2, 1)) = 'e', " +
+ "UPPER(NAME) = 'JEN', LOWER(NAME) = 'jen']"
+ checkPushedInfo(df1, expectedPlanFragment1)
+ checkAnswer(df1, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ val df2 = sql("select * FROM h2.test.employee where " +
+ "trim(name) = 'jen' AND trim('j', name) = 'en'" +
+ "AND translate(name, 'e', 1) = 'j1n'")
+ checkFiltersRemoved(df2)
+ val expectedPlanFragment2 =
+ "PushedFilters: [NAME IS NOT NULL, TRIM(BOTH FROM NAME) = 'jen', " +
+ "(TRIM(BOTH 'j' FROM NAME)) = 'en', (TRANSLATE(NA..."
+ checkPushedInfo(df2, expectedPlanFragment2)
+ checkAnswer(df2, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ val df3 = sql("select * FROM h2.test.employee where " +
+ "ltrim(name) = 'jen' AND ltrim('j', name) = 'en'")
+ checkFiltersRemoved(df3)
+ val expectedPlanFragment3 =
+ "PushedFilters: [TRIM(LEADING FROM NAME) = 'jen', " +
+ "(TRIM(LEADING 'j' FROM NAME)) = 'en']"
+ checkPushedInfo(df3, expectedPlanFragment3)
+ checkAnswer(df3, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ val df4 = sql("select * FROM h2.test.employee where " +
+ "rtrim(name) = 'jen' AND rtrim('n', name) = 'je'")
+ checkFiltersRemoved(df4)
+ val expectedPlanFragment4 =
+ "PushedFilters: [TRIM(TRAILING FROM NAME) = 'jen', " +
+ "(TRIM(TRAILING 'n' FROM NAME)) = 'je']"
+ checkPushedInfo(df4, expectedPlanFragment4)
+ checkAnswer(df4, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ // H2 does not support OVERLAY
+ val df5 = sql("select * FROM h2.test.employee where OVERLAY(NAME, '1', 2, 1) = 'j1n'")
+ checkFiltersRemoved(df5, false)
+ val expectedPlanFragment5 =
+ "PushedFilters: [NAME IS NOT NULL]"
+ checkPushedInfo(df5, expectedPlanFragment5)
+ checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true)))
+ }
+
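
For a quick manual check of what the test asserts, the pushed filters appear on the PushedFilters line of the formatted plan (assuming a SparkSession named spark with the suite's h2 catalog registered):

    // TRIM(BOTH FROM NAME) = 'jen' should show up under PushedFilters.
    spark.sql("SELECT * FROM h2.test.employee WHERE trim(name) = 'jen'")
      .explain("formatted")
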
test("scan with aggregate push-down: MAX AVG with filter and group by") {
val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" +
" group by DePt")