diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 15071c960a2ba..4edc3521d841d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -148,6 +148,54 @@
  *    <li>Since version: 3.3.0</li>
  *   </ul>
  *  </li>
+ *  <li>Name: <code>SUBSTRING</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SUBSTRING(str, pos[, len])</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>UPPER</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>UPPER(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>LOWER</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>LOWER(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>TRANSLATE</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>TRANSLATE(input, from, to)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>TRIM</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>TRIM(src, trim)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>LTRIM</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>LTRIM(src, trim)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>RTRIM</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>RTRIM(src, trim)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>OVERLAY</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>OVERLAY(string, replace, position[, length])</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
  * </ol>
  * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off,
  * including: add, subtract, multiply, divide, remainder, pmod.
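Note: each name in the list above travels as a `GeneralScalarExpression`, which is what a connector actually receives after push-down. A minimal sketch of the value Spark builds for `UPPER(name)` — the column name here is hypothetical:

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression}

// UPPER with a single child: a reference to the (hypothetical) column `name`.
val upperName = new GeneralScalarExpression(
  "UPPER", Array[Expression](FieldReference("name")))
assert(upperName.name == "UPPER")
assert(upperName.children.length == 1)
```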
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index c9dfa2003e3c1..396b1d9cdd034 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -102,6 +102,10 @@ public String build(Expression expr) {
       case "FLOOR":
       case "CEIL":
       case "WIDTH_BUCKET":
+      case "SUBSTRING":
+      case "UPPER":
+      case "LOWER":
+      case "TRANSLATE":
         return visitSQLFunction(name,
           Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
       case "CASE_WHEN": {
@@ -109,6 +113,18 @@ public String build(Expression expr) {
           Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
         return visitCaseWhen(children.toArray(new String[e.children().length]));
       }
+      case "TRIM":
+        return visitTrim("BOTH",
+          Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+      case "LTRIM":
+        return visitTrim("LEADING",
+          Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+      case "RTRIM":
+        return visitTrim("TRAILING",
+          Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+      case "OVERLAY":
+        return visitOverlay(
+          Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
       // TODO supports other expressions
       default:
         return visitUnexpectedExpr(expr);
@@ -228,4 +244,23 @@ protected String visitSQLFunction(String funcName, String[] inputs) {
   protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
     throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
   }
+
+  protected String visitOverlay(String[] inputs) {
+    assert(inputs.length == 3 || inputs.length == 4);
+    if (inputs.length == 3) {
+      return "OVERLAY(" + inputs[0] + " PLACING " + inputs[1] + " FROM " + inputs[2] + ")";
+    } else {
+      return "OVERLAY(" + inputs[0] + " PLACING " + inputs[1] + " FROM " + inputs[2] +
+        " FOR " + inputs[3] + ")";
+    }
+  }
+
+  protected String visitTrim(String direction, String[] inputs) {
+    assert(inputs.length == 1 || inputs.length == 2);
+    if (inputs.length == 1) {
+      return "TRIM(" + direction + " FROM " + inputs[0] + ")";
+    } else {
+      return "TRIM(" + direction + " " + inputs[1] + " FROM " + inputs[0] + ")";
+    }
+  }
 }
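Note: TRIM, LTRIM, RTRIM and OVERLAY get dedicated visit methods here because they render as ANSI keyword syntax (`TRIM(BOTH ... FROM ...)`, `OVERLAY(... PLACING ... FROM ...)`) rather than as plain function calls, so dialects can override just these two hooks. A rough sketch of the default output, assuming `FieldReference("name")` prints as the bare column name:

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
import org.apache.spark.sql.types.StringType

// TRIM with children (source, trimString), mirroring Catalyst's child order.
val trim = new GeneralScalarExpression("TRIM", Array[Expression](
  FieldReference("name"), LiteralValue("j", StringType)))

val sql = new V2ExpressionSQLBuilder().build(trim)
// Expected output (roughly): TRIM(BOTH 'j' FROM name)
```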
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 487b809d48a01..a2f35ec1152d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Lower, Multiply, Not, Or, Overlay, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, StringTranslate, StringTrim, StringTrimLeft, StringTrimRight, Substring, Subtract, UnaryMinus, Upper, WidthBucket}
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -200,6 +200,65 @@ class V2ExpressionBuilder(
     } else {
       None
     }
+    case substring: Substring =>
+      val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
+        Seq(substring.str, substring.pos)
+      } else {
+        substring.children
+      }
+      val childrenExpressions = children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == children.length) {
+        Some(new GeneralScalarExpression("SUBSTRING",
+          childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
+    case Upper(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v)))
+    case Lower(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v)))
+    case translate: StringTranslate =>
+      val childrenExpressions = translate.children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == translate.children.length) {
+        Some(new GeneralScalarExpression("TRANSLATE",
+          childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
+    case trim: StringTrim =>
+      val childrenExpressions = trim.children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == trim.children.length) {
+        Some(new GeneralScalarExpression("TRIM", childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
+    case trim: StringTrimLeft =>
+      val childrenExpressions = trim.children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == trim.children.length) {
+        Some(new GeneralScalarExpression("LTRIM", childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
+    case trim: StringTrimRight =>
+      val childrenExpressions = trim.children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == trim.children.length) {
+        Some(new GeneralScalarExpression("RTRIM", childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
+    case overlay: Overlay =>
+      val children = if (overlay.len == Literal(-1)) {
+        Seq(overlay.input, overlay.replace, overlay.pos)
+      } else {
+        overlay.children
+      }
+      val childrenExpressions = children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == children.length) {
+        Some(new GeneralScalarExpression("OVERLAY",
+          childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
     // TODO supports other expressions
     case _ => None
   }
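Note on the two `val children = ...` blocks: Catalyst encodes an omitted SUBSTRING length as `Literal(Integer.MAX_VALUE)` and an omitted OVERLAY length as `Literal(-1)`, so the builder strips those sentinels and pushes the short two- or three-argument form instead of a sentinel literal. A standalone restatement of that normalization (a sketch only; `substringChildren` is not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Substring}

// Two-argument substr(str, pos) is parsed as Substring(str, pos, Int.MaxValue).
def substringChildren(s: Substring): Seq[Expression] =
  if (s.len == Literal(Integer.MAX_VALUE)) Seq(s.str, s.pos) else s.children

val twoArg = Substring(Literal("jen"), Literal(2), Literal(Integer.MAX_VALUE))
// Pushed as SUBSTRING(str, pos), not SUBSTRING(str, pos, 2147483647).
assert(substringChildren(twoArg) == Seq(twoArg.str, twoArg.pos))
```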
"LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL") + Set("ABS", "COALESCE", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL", + "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM") override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index e1883e4e7f4b8..9d37d8ef334b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -251,6 +251,24 @@ abstract class JdbcDialect extends Serializable with Logging{ s"${this.getClass.getSimpleName} does not support function: $funcName") } } + + override def visitOverlay(inputs: Array[String]): String = { + if (isSupportedFunction("OVERLAY")) { + super.visitOverlay(inputs) + } else { + throw new UnsupportedOperationException( + s"${this.getClass.getSimpleName} does not support function: OVERLAY") + } + } + + override def visitTrim(direction: String, inputs: Array[String]): String = { + if (isSupportedFunction("TRIM")) { + super.visitTrim(direction, inputs) + } else { + throw new UnsupportedOperationException( + s"${this.getClass.getSimpleName} does not support function: TRIM") + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index fd186b764fb4c..b9aa0486648d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -675,6 +675,54 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } + test("scan with filter push-down with string functions") { + val df1 = sql("select * FROM h2.test.employee where " + + "substr(name, 2, 1) = 'e'" + + " AND upper(name) = 'JEN' AND lower(name) = 'jen' ") + checkFiltersRemoved(df1) + val expectedPlanFragment1 = + "PushedFilters: [NAME IS NOT NULL, (SUBSTRING(NAME, 2, 1)) = 'e', " + + "UPPER(NAME) = 'JEN', LOWER(NAME) = 'jen']" + checkPushedInfo(df1, expectedPlanFragment1) + checkAnswer(df1, Seq(Row(6, "jen", 12000, 1200, true))) + + val df2 = sql("select * FROM h2.test.employee where " + + "trim(name) = 'jen' AND trim('j', name) = 'en'" + + "AND translate(name, 'e', 1) = 'j1n'") + checkFiltersRemoved(df2) + val expectedPlanFragment2 = + "PushedFilters: [NAME IS NOT NULL, TRIM(BOTH FROM NAME) = 'jen', " + + "(TRIM(BOTH 'j' FROM NAME)) = 'en', (TRANSLATE(NA..." 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index fd186b764fb4c..b9aa0486648d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -675,6 +675,54 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
   }
 
+  test("scan with filter push-down with string functions") {
+    val df1 = sql("select * FROM h2.test.employee where " +
+      "substr(name, 2, 1) = 'e'" +
+      " AND upper(name) = 'JEN' AND lower(name) = 'jen' ")
+    checkFiltersRemoved(df1)
+    val expectedPlanFragment1 =
+      "PushedFilters: [NAME IS NOT NULL, (SUBSTRING(NAME, 2, 1)) = 'e', " +
+        "UPPER(NAME) = 'JEN', LOWER(NAME) = 'jen']"
+    checkPushedInfo(df1, expectedPlanFragment1)
+    checkAnswer(df1, Seq(Row(6, "jen", 12000, 1200, true)))
+
+    val df2 = sql("select * FROM h2.test.employee where " +
+      "trim(name) = 'jen' AND trim('j', name) = 'en'" +
+      "AND translate(name, 'e', 1) = 'j1n'")
+    checkFiltersRemoved(df2)
+    val expectedPlanFragment2 =
+      "PushedFilters: [NAME IS NOT NULL, TRIM(BOTH FROM NAME) = 'jen', " +
+        "(TRIM(BOTH 'j' FROM NAME)) = 'en', (TRANSLATE(NA..."
+    checkPushedInfo(df2, expectedPlanFragment2)
+    checkAnswer(df2, Seq(Row(6, "jen", 12000, 1200, true)))
+
+    val df3 = sql("select * FROM h2.test.employee where " +
+      "ltrim(name) = 'jen' AND ltrim('j', name) = 'en'")
+    checkFiltersRemoved(df3)
+    val expectedPlanFragment3 =
+      "PushedFilters: [TRIM(LEADING FROM NAME) = 'jen', " +
+        "(TRIM(LEADING 'j' FROM NAME)) = 'en']"
+    checkPushedInfo(df3, expectedPlanFragment3)
+    checkAnswer(df3, Seq(Row(6, "jen", 12000, 1200, true)))
+
+    val df4 = sql("select * FROM h2.test.employee where " +
+      "rtrim(name) = 'jen' AND rtrim('n', name) = 'je'")
+    checkFiltersRemoved(df4)
+    val expectedPlanFragment4 =
+      "PushedFilters: [TRIM(TRAILING FROM NAME) = 'jen', " +
+        "(TRIM(TRAILING 'n' FROM NAME)) = 'je']"
+    checkPushedInfo(df4, expectedPlanFragment4)
+    checkAnswer(df4, Seq(Row(6, "jen", 12000, 1200, true)))
+
+    // H2 does not support OVERLAY
+    val df5 = sql("select * FROM h2.test.employee where OVERLAY(NAME, '1', 2, 1) = 'j1n'")
+    checkFiltersRemoved(df5, false)
+    val expectedPlanFragment5 =
+      "PushedFilters: [NAME IS NOT NULL]"
+    checkPushedInfo(df5, expectedPlanFragment5)
+    checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true)))
+  }
+
   test("scan with aggregate push-down: MAX AVG with filter and group by") {
     val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" +
       " group by DePt")
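The expected fragments above can also be reproduced interactively; a spark-shell sketch against the same H2 catalog the suite configures (the catalog and table names are the suite's, not general defaults):

```scala
// With spark.sql.catalog.h2 configured as in JDBCV2Suite:
val df = spark.sql("SELECT * FROM h2.test.employee WHERE upper(name) = 'JEN'")
df.explain()
// The DSv2 scan node should report:
//   PushedFilters: [NAME IS NOT NULL, UPPER(NAME) = 'JEN']
```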