Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-30625][SQL] Support escape as third parameter of the like function #27355

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ package object dsl {
}

def like(other: Expression, escapeChar: Char = '\\'): Expression =
Like(expr, other, escapeChar)
Like(expr, other, Literal(escapeChar.toString))
def rlike(other: Expression): Expression = RLike(expr, other)
def contains(other: Expression): Expression = Contains(expr, other)
def startsWith(other: Expression): Expression = StartsWith(expr, other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,28 @@ import java.util.regex.{MatchResult, Pattern}

import org.apache.commons.text.StringEscapeUtils

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


abstract class StringRegexExpression extends BinaryExpression
trait StringRegexExpression extends Expression
with ImplicitCastInputTypes with NullIntolerant {

def str: Expression
def pattern: Expression

def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean

override def dataType: DataType = BooleanType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

// try cache the pattern for Literal
private lazy val cache: Pattern = right match {
case x @ Literal(value: String, StringType) => compile(value)
private lazy val cache: Pattern = pattern match {
case Literal(value: String, StringType) => compile(value)
case _ => null
}

Expand All @@ -51,18 +54,17 @@ abstract class StringRegexExpression extends BinaryExpression
Pattern.compile(escape(str))
}

protected def pattern(str: String) = if (cache == null) compile(str) else cache

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val regex = pattern(input2.asInstanceOf[UTF8String].toString)
/**
 * Evaluates the match for non-null inputs.
 *
 * @param input1 the string to test, as a UTF8String
 * @param input2 the pattern text, as a UTF8String
 * @return the result of `matches` for the resolved regex, or null when no regex is available
 */
def nullSafeMatch(input1: Any, input2: Any): Any = {
  val patternText = input2.asInstanceOf[UTF8String].toString
  // Reuse the pattern cached for a literal; otherwise compile the current pattern text.
  val compiled = if (cache != null) cache else compile(patternText)
  if (compiled != null) {
    matches(compiled, input1.asInstanceOf[UTF8String].toString)
  } else {
    null
  }
}

override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}"
override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}"
}

// scalastyle:off line.contains.tab
Expand Down Expand Up @@ -107,46 +109,65 @@ abstract class StringRegexExpression extends BinaryExpression
true
> SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/';
true
> SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_');
true
""",
note = """
Use RLIKE to match with standard regular expressions.
""",
since = "1.0.0")
// scalastyle:on line.contains.tab
case class Like(left: Expression, right: Expression, escapeChar: Char)
extends StringRegexExpression {
case class Like(str: Expression, pattern: Expression, escape: Expression)
extends TernaryExpression with StringRegexExpression {

def this(left: Expression, right: Expression) = this(left, right, '\\')
def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\"))

override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
override def children: Seq[Expression] = Seq(str, pattern, escape)

private lazy val escapeChar: Char = if (escape.foldable) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need `lazy` here? I personally think it's better to check error conditions as soon as possible.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I remove lazy, I get the exception:

makeCopy, tree: str#13580 LIKE pattern#13581 ESCAPE '@'
org.apache.spark.sql.catalyst.errors.package$TreeNodeException: makeCopy, tree: str#13580 LIKE pattern#13581 ESCAPE '@'
	at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
	at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:435)
	at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:408)
...
Caused by: org.apache.spark.sql.AnalysisException: The 'escape' parameter must be a string literal.;
	at org.apache.spark.sql.catalyst.expressions.Like.<init>(regexpExpressions.scala:135)

Note that `escape` is a `PrettyAttribute` here:
Screen Shot 2020-01-26 at 19 51 34

Copy link
Member Author

@MaxGekk MaxGekk Jan 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got the error when I ran the new test in DataFrameFunctionsSuite, and I remember getting it in other tests as well. PrettyAttribute is not foldable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, OK. Thanks for checking. The current version looks OK to me.

escape.eval() match {
case s: UTF8String if s != null && s.numChars() == 1 => s.toString.charAt(0)
case s => throw new AnalysisException(
s"The 'escape' parameter must be a string literal of one char but it is $s.")
}
} else {
throw new AnalysisException("The 'escape' parameter must be a string literal.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add tests for this path and line 131 (error cases)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This kind of thing should be done in checkInputDataTypes.

}

/** Converts pattern text with `StringUtils.escapeLikeRegex`, using this expression's escape char. */
override def escape(v: String): String = {
  StringUtils.escapeLikeRegex(v, escapeChar)
}

/** Returns true only when the compiled pattern matches the entire input string. */
override def matches(regex: Pattern, str: String): Boolean = {
  val matcher = regex.matcher(str)
  matcher.matches()
}

override def toString: String = escapeChar match {
case '\\' => s"$left LIKE $right"
case c => s"$left LIKE $right ESCAPE '$c'"
case '\\' => s"$str LIKE $pattern"
case c => s"$str LIKE $pattern ESCAPE '$c'"
}

// The evaluated escape value (input3) is unused here: the escape character is applied
// through `escape`, so only the string and the pattern are forwarded.
protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any =
  nullSafeMatch(input1, input2)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val patternClass = classOf[Pattern].getName
val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"

if (right.foldable) {
val rVal = right.eval()
if (rVal != null) {
if (pattern.foldable) {
val patternVal = pattern.eval()
if (patternVal != null) {
val regexStr =
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
val pattern = ctx.addMutableState(patternClass, "patternLike",
StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString()))
val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern",
v => s"""$v = $patternClass.compile("$regexStr");""")

// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
val eval = str.genCode(ctx)
ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches();
}
""")
} else {
Expand All @@ -164,18 +185,18 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
} else {
escapeChar
}
val rightStr = ctx.freshName("rightStr")
val pattern = ctx.addMutableState(patternClass, "pattern")
val lastRightStr = ctx.addMutableState(classOf[String].getName, "lastRightStr")
val patternStr = ctx.freshName("patternStr")
val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern")
val lastPatternStr = ctx.addMutableState(classOf[String].getName, "lastPatternStr")

nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => {
s"""
String $rightStr = $eval2.toString();
if (!$rightStr.equals($lastRightStr)) {
$pattern = $patternClass.compile($escapeFunc($rightStr, '$newEscapeChar'));
$lastRightStr = $rightStr;
String $patternStr = $eval2.toString();
if (!$patternStr.equals($lastPatternStr)) {
$compiledPattern = $patternClass.compile($escapeFunc($patternStr, '$newEscapeChar'));
$lastPatternStr = $patternStr;
}
${ev.value} = $pattern.matcher($eval1.toString()).matches();
${ev.value} = $compiledPattern.matcher($eval1.toString()).matches();
"""
})
}
Expand Down Expand Up @@ -214,12 +235,20 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
""",
since = "1.0.0")
// scalastyle:on line.contains.tab
case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
case class RLike(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {

override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

override def str: Expression = left
override def pattern: Expression = right

/** Returns the pattern text unchanged; no escaping is applied. */
override def escape(v: String): String = {
  v
}
/** Returns true when the pattern is found anywhere in the input, searching from index 0. */
override def matches(regex: Pattern, str: String): Boolean = {
  val matcher = regex.matcher(str)
  matcher.find(0)
}
override def toString: String = s"$left RLIKE $right"

/** Forwards the non-null string and pattern values to `nullSafeMatch`. */
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
  nullSafeMatch(input1, input2)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val patternClass = classOf[Pattern].getName

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
throw new ParseException("Invalid escape string." +
"Escape string must contains only one character.", ctx)
}
str.charAt(0)
str
}.getOrElse('\\')
invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar))
invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar)))
case SqlBaseParser.RLIKE =>
invertIfNotDefined(RLike(e, expression(ctx.pattern)))
case SqlBaseParser.NULL if ctx.NOT != null =>
Expand Down