diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index b4a8bafe22dfb..40998080bc4e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -99,7 +99,7 @@ package object dsl {
     }
 
     def like(other: Expression, escapeChar: Char = '\\'): Expression =
-      Like(expr, other, escapeChar)
+      Like(expr, other, Literal(escapeChar.toString))
     def rlike(other: Expression): Expression = RLike(expr, other)
     def contains(other: Expression): Expression = Contains(expr, other)
     def startsWith(other: Expression): Expression = StartsWith(expr, other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 2354087768615..c9ddc70bf5bc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -22,6 +22,7 @@ import java.util.regex.{MatchResult, Pattern}
 
 import org.apache.commons.text.StringEscapeUtils
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
@@ -29,18 +30,22 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 
-abstract class StringRegexExpression extends BinaryExpression
+trait StringRegexExpression extends Expression
   with ImplicitCastInputTypes with NullIntolerant {
 
+  /** The expression producing the input string that is matched against `pattern`. */
+  def str: Expression
+  /** The expression producing the pattern that `str` is matched against. */
+  def pattern: Expression
+
   def escape(v: String): String
   def matches(regex: Pattern, str: String): Boolean
 
   override def dataType: DataType = BooleanType
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
 
   // try cache the pattern for Literal
-  private lazy val cache: Pattern = right match {
-    case x @ Literal(value: String, StringType) => compile(value)
+  private lazy val cache: Pattern = pattern match {
+    case Literal(value: String, StringType) => compile(value)
     case _ => null
   }
 
@@ -51,10 +56,13 @@ abstract class StringRegexExpression extends BinaryExpression
     Pattern.compile(escape(str))
   }
 
-  protected def pattern(str: String) = if (cache == null) compile(str) else cache
-
-  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
-    val regex = pattern(input2.asInstanceOf[UTF8String].toString)
+  /**
+   * Matches the string value `input1` against the pattern value `input2`; both are cast to
+   * `UTF8String`. Returns null when no regex is available for the pattern.
+   */
+  def nullSafeMatch(input1: Any, input2: Any): Any = {
+    val s = input2.asInstanceOf[UTF8String].toString
+    val regex = if (cache == null) compile(s) else cache
     if(regex == null) {
       null
     } else {
@@ -62,7 +70,7 @@ abstract class StringRegexExpression extends BinaryExpression
     }
   }
 
-  override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}"
+  override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}"
 }
 
 // scalastyle:off line.contains.tab
@@ -107,46 +115,67 @@ abstract class StringRegexExpression extends BinaryExpression
        true
       > SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/';
        true
+      > SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_');
+       true
   """,
   note = """
     Use RLIKE to match with standard regular expressions.
   """,
   since = "1.0.0")
 // scalastyle:on line.contains.tab
-case class Like(left: Expression, right: Expression, escapeChar: Char)
-  extends StringRegexExpression {
+case class Like(str: Expression, pattern: Expression, escape: Expression)
+  extends TernaryExpression with StringRegexExpression {
 
-  def this(left: Expression, right: Expression) = this(left, right, '\\')
+  def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\"))
+
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
+  override def children: Seq[Expression] = Seq(str, pattern, escape)
+
+  // Resolved lazily from the foldable `escape` expression, which must evaluate to a string of
+  // exactly one character. The type pattern rejects null, so no explicit null guard is needed.
+  private lazy val escapeChar: Char = if (escape.foldable) {
+    escape.eval() match {
+      case s: UTF8String if s.numChars() == 1 => s.toString.charAt(0)
+      case s => throw new AnalysisException(
+        s"The 'escape' parameter must be a string literal of one char but it is $s.")
+    }
+  } else {
+    throw new AnalysisException("The 'escape' parameter must be a string literal.")
+  }
 
   override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar)
 
   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
 
   override def toString: String = escapeChar match {
-    case '\\' => s"$left LIKE $right"
-    case c => s"$left LIKE $right ESCAPE '$c'"
+    case '\\' => s"$str LIKE $pattern"
+    case c => s"$str LIKE $pattern ESCAPE '$c'"
+  }
+
+  protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
+    nullSafeMatch(input1, input2)
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val patternClass = classOf[Pattern].getName
     val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
 
-    if (right.foldable) {
-      val rVal = right.eval()
-      if (rVal != null) {
+    if (pattern.foldable) {
+      val patternVal = pattern.eval()
+      if (patternVal != null) {
         val regexStr =
-          StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
-        val pattern = ctx.addMutableState(patternClass, "patternLike",
+          StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString()))
+        val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern",
           v => s"""$v = $patternClass.compile("$regexStr");""")
 
         // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
-        val eval = left.genCode(ctx)
+        val eval = str.genCode(ctx)
         ev.copy(code = code"""
           ${eval.code}
           boolean ${ev.isNull} = ${eval.isNull};
           ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
           if (!${ev.isNull}) {
-            ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
+            ${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches();
           }
         """)
       } else {
@@ -164,18 +193,18 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
       } else {
         escapeChar
       }
-      val rightStr = ctx.freshName("rightStr")
-      val pattern = ctx.addMutableState(patternClass, "pattern")
-      val lastRightStr = ctx.addMutableState(classOf[String].getName, "lastRightStr")
+      val patternStr = ctx.freshName("patternStr")
+      val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern")
+      val lastPatternStr = ctx.addMutableState(classOf[String].getName, "lastPatternStr")
 
-      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+      nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => {
         s"""
-          String $rightStr = $eval2.toString();
-          if (!$rightStr.equals($lastRightStr)) {
-            $pattern = $patternClass.compile($escapeFunc($rightStr, '$newEscapeChar'));
-            $lastRightStr = $rightStr;
+          String $patternStr = $eval2.toString();
+          if (!$patternStr.equals($lastPatternStr)) {
+            $compiledPattern = $patternClass.compile($escapeFunc($patternStr, '$newEscapeChar'));
+            $lastPatternStr = $patternStr;
           }
-          ${ev.value} = $pattern.matcher($eval1.toString()).matches();
+          ${ev.value} = $compiledPattern.matcher($eval1.toString()).matches();
         """
       })
     }
@@ -214,12 +243,20 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
   """,
   since = "1.0.0")
 // scalastyle:on line.contains.tab
-case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
+case class RLike(left: Expression, right: Expression)
+  extends BinaryExpression with StringRegexExpression {
+
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+
+  override def str: Expression = left
+  override def pattern: Expression = right
 
   override def escape(v: String): String = v
   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
   override def toString: String = s"$left RLIKE $right"
 
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = nullSafeMatch(input1, input2)
+
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val patternClass = classOf[Pattern].getName
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index e1dca4e945397..967eca77145e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
             throw new ParseException("Invalid escape string." +
               "Escape string must contains only one character.", ctx)
           }
-          str.charAt(0)
-        }.getOrElse('\\')
-        invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar))
+          str
+        }.getOrElse("\\")
+        invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar)))
       case SqlBaseParser.RLIKE =>
         invertIfNotDefined(RLike(e, expression(ctx.pattern)))
       case SqlBaseParser.NULL if ctx.NOT != null =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 584768eff700b..7fce03658fc16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -3562,6 +3562,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
       Seq(Row(1)))
   }
+
+  test("the like function with the escape parameter") {
+    val df = Seq(("abc", "a_c", "!")).toDF("str", "pattern", "escape")
+    checkAnswer(df.selectExpr("like(str, pattern, '@')"), Row(true))
+
+    val longEscapeError = intercept[AnalysisException] {
+      df.selectExpr("like(str, pattern, '@%')").collect()
+    }.getMessage
+    assert(longEscapeError.contains("The 'escape' parameter must be a string literal of one char"))
+
+    val nonFoldableError = intercept[AnalysisException] {
+      df.selectExpr("like(str, pattern, escape)").collect()
+    }.getMessage
+    assert(nonFoldableError.contains("The 'escape' parameter must be a string literal"))
+  }
 }
 
 object DataFrameFunctionsSuite {