Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add new regexp fuzz tests that run with unicode input #5573

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -587,9 +587,8 @@ Here are some examples of regular expression patterns that are not supported on
- Line anchor `$` is not supported by `regexp_replace`, and in some rare contexts.
- String anchor `\Z` is not supported by `regexp_replace`, and in some rare contexts.
- String anchor `\z` is not supported by `regexp_replace`
- Line anchor `$` and string anchors `\z` and `\Z` are not supported in patterns containing `\W` or `\D`
- Line and string anchors are not supported by `string_split` and `str_to_map`
- Non-digit character class `\D`
- Non-word character class `\W`
- Word and non-word boundaries, `\b` and `\B`
- Whitespace and non-whitespace characters, `\s` and `\S`
- Lazy quantifiers, such as `a*?`
Expand Down
10 changes: 8 additions & 2 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,21 +951,27 @@ def test_regexp_octal_digits():

# Runs regexp_replace over generated strings with the patterns \d, \D, [0-9] and [^0-9],
# replacing each match with "x", and passes the query to
# assert_gpu_and_cpu_are_equal_collect (by its name, a CPU-vs-GPU result comparison).
# Special-case inputs visible here: a multi-byte unicode string and a string containing
# \n and \r\n line terminators.
def test_regexp_replace_digit():
gen = mk_str_gen('[a-z]{0,2}[0-9]{0,2}') \
.with_special_case('䤫畍킱곂⬡❽ࢅ獰᳌蛫青')
.with_special_case('䤫畍킱곂⬡❽ࢅ獰᳌蛫青') \
.with_special_case('a\n2\r\n3')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp_replace(a, "\\\\d", "x")',
'regexp_replace(a, "\\\\D", "x")',
'regexp_replace(a, "[0-9]", "x")',
'regexp_replace(a, "[^0-9]", "x")',
),
conf=_regexp_conf)

# Runs regexp_replace over generated strings with the patterns \w, \W, [a-zA-Z_0-9] and
# [^a-zA-Z_0-9], replacing each match with "x", and passes the query to
# assert_gpu_and_cpu_are_equal_collect (by its name, a CPU-vs-GPU result comparison).
# Special-case inputs visible here: a multi-byte unicode string and a string containing
# \n and \r\n line terminators.
def test_regexp_replace_word():
gen = mk_str_gen('[a-z]{0,2}[_]{0,1}[0-9]{0,2}') \
.with_special_case('䤫畍킱곂⬡❽ࢅ獰᳌蛫青')
.with_special_case('䤫畍킱곂⬡❽ࢅ獰᳌蛫青') \
.with_special_case('a\n2\r\n3')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp_replace(a, "\\\\w", "x")',
'regexp_replace(a, "\\\\W", "x")',
'regexp_replace(a, "[a-zA-Z_0-9]", "x")',
'regexp_replace(a, "[^a-zA-Z_0-9]", "x")',
),
conf=_regexp_conf)

Expand Down
61 changes: 49 additions & 12 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
val replacement = repl.map(s => new RegexParser(s).parseReplacement(countCaptureGroups(regex)))

// validate that the regex is supported by cuDF
val cudfRegex = rewrite(regex, replacement, None)
val cudfRegex = transpile(regex, replacement, None)
// write out to regex string, performing minor transformations
// such as adding additional escaping
(cudfRegex.toRegexString, replacement.map(_.toRegexString))
Expand Down Expand Up @@ -696,6 +696,30 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

/**
 * Screen a parsed expression for a known-problematic combination and, if it passes,
 * hand it to `rewrite` with the arguments unchanged.
 *
 * The whole tree is scanned for a negated word/digit escape (`\W` or `\D`) and for an
 * end anchor (`$`, `\Z` or `\z`). The scan is deliberately broad: the two only have to
 * occur somewhere in the same expression, not next to each other.
 *
 * @param regex the parsed expression to check and rewrite
 * @param replacement optional parsed replacement, forwarded to `rewrite`
 * @param previous optional preceding node, forwarded to `rewrite`
 * @return whatever `rewrite` returns for the same arguments
 * @throws RegexUnsupportedException if the expression contains both a negated
 *                                   word/digit escape and an end anchor
 */
private def transpile(regex: RegexAST, replacement: Option[RegexReplacement],
    previous: Option[RegexAST]): RegexAST = {

  // leaf predicate: \W or \D
  def isNegatedWordOrDigit(node: RegexAST): Boolean = node match {
    case RegexEscaped('W') | RegexEscaped('D') => true
    case _ => false
  }

  // leaf predicate: $, \Z or \z
  def isEndAnchor(node: RegexAST): Boolean = node match {
    case RegexChar('$') | RegexEscaped('Z') | RegexEscaped('z') => true
    case _ => false
  }

  // both scans run up front, before any rewriting is attempted
  val hasNegatedClass = contains(regex, isNegatedWordOrDigit)
  val hasEndAnchor = contains(regex, isEndAnchor)

  // this could potentially be narrowed to \W or \D immediately next to an anchor
  if (hasNegatedClass && hasEndAnchor) {
    throw new RegexUnsupportedException(
      "Combination of \\W or \\D with line anchor $ " +
        "or string anchors \\z or \\Z is not supported")
  }

  rewrite(regex, replacement, previous)
}

private def rewrite(regex: RegexAST, replacement: Option[RegexReplacement],
previous: Option[RegexAST]): RegexAST = {
regex match {
Expand Down Expand Up @@ -792,25 +816,29 @@ class CudfRegexTranspiler(mode: RegexMode) {
}

case RegexEscaped(ch) => ch match {
case 'd' =>
case 'd' | 'D' =>
// cuDF is not compatible with Java for \d so we transpile to Java's definition
// of [0-9]
// https://github.com/rapidsai/cudf/issues/10894
RegexCharacterClass(negated = false, ListBuffer(RegexCharacterRange('0', '9')))
case 'w' =>
val components = ListBuffer[RegexCharacterClassComponent](RegexCharacterRange('0', '9'))
if (ch.isUpper) {
negateCharacterClass(components)
} else {
RegexCharacterClass(negated = false, components)
}
case 'w' | 'W' =>
// cuDF is not compatible with Java for \w so we transpile to Java's definition
// of `[a-zA-Z_0-9]`
RegexCharacterClass(negated = false, ListBuffer(
val components = ListBuffer[RegexCharacterClassComponent](
RegexCharacterRange('a', 'z'),
RegexCharacterRange('A', 'Z'),
RegexChar('_'),
RegexCharacterRange('0', '9')))
case 'D' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4475
throw new RegexUnsupportedException("non-digit class \\D is not supported")
case 'W' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4475
throw new RegexUnsupportedException("non-word class \\W is not supported")
RegexCharacterRange('0', '9'))
if (ch.isUpper) {
negateCharacterClass(components)
} else {
RegexCharacterClass(negated = false, components)
}
case 'b' | 'B' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4517
throw new RegexUnsupportedException("word boundaries are not supported")
Expand Down Expand Up @@ -1162,6 +1190,15 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

/**
 * Report whether any leaf of the expression tree satisfies `f`.
 *
 * Sequences, groups, choices, repetitions and character classes are descended into;
 * `f` is applied only to nodes that are none of those.
 *
 * @param regex root of the tree to search
 * @param f predicate applied to each leaf node
 * @return true if `f` holds for at least one leaf
 */
private def contains(regex: RegexAST, f: RegexAST => Boolean): Boolean = {
  def search(node: RegexAST): Boolean = node match {
    case RegexCharacterClass(_, members) => members.exists(m => search(m))
    case RegexRepetition(inner, _) => search(inner)
    case RegexChoice(left, right) => search(left) || search(right)
    case RegexGroup(_, inner) => search(inner)
    case RegexSequence(children) => children.exists(c => search(c))
    case other => f(other)
  }
  search(regex)
}

private def isBeginOrEndLineAnchor(regex: RegexAST): Boolean = regex match {
case RegexSequence(parts) => parts.nonEmpty && parts.forall(isBeginOrEndLineAnchor)
case RegexGroup(_, term) => isBeginOrEndLineAnchor(term)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}
}

test("Detect unsupported combinations of line anchors and \\W and \\D") {
  // every pattern pairs a negated class (\W or \D) with an end anchor ($ or \Z)
  val expectedMessage = "Combination of \\W or \\D with line anchor $ " +
    "or string anchors \\z or \\Z is not supported"
  for (unsupported <- Seq("\\W\\Z\\D", "\\W$", "$\\D")) {
    assertUnsupported(unsupported, RegexFindMode, expectedMessage)
  }
}

test("cuDF does not support choice with nothing to repeat") {
val patterns = Seq("b+|^\t")
patterns.foreach(pattern =>
Expand Down Expand Up @@ -383,20 +392,6 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

test("fall back to CPU for \\D") {
  // see https://github.com/NVIDIA/spark-rapids/issues/4475
  // the same rejection is expected in both find and replace modes
  Seq(RegexFindMode, RegexReplaceMode).foreach { regexMode =>
    assertUnsupported("\\D", regexMode, "non-digit class \\D is not supported")
  }
}

test("fall back to CPU for \\W") {
  // see https://github.com/NVIDIA/spark-rapids/issues/4475
  // the same rejection is expected in both find and replace modes
  Seq(RegexFindMode, RegexReplaceMode).foreach { regexMode =>
    assertUnsupported("\\W", regexMode, "non-word class \\W is not supported")
  }
}

test("compare CPU and GPU: replace digits") {
// note that we do not test with quantifiers `?` or `*` due
// to https://github.com/NVIDIA/spark-rapids/issues/4468
Expand Down Expand Up @@ -537,11 +532,23 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}

test("AST fuzz test - regexp_find") {
// Fuzz test in find mode. In the three-argument call, REGEXP_LIMITED_CHARS_FIND is passed
// both as the data character set and as the pattern character set.
// NOTE(review): a two-argument and a three-argument call both appear in this span.
doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), RegexFindMode)
doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), REGEXP_LIMITED_CHARS_FIND,
RegexFindMode)
}

test("AST fuzz test - regexp_replace") {
// Fuzz test in replace mode. In the three-argument call, REGEXP_LIMITED_CHARS_REPLACE is
// passed both as the data character set and as the pattern character set.
// NOTE(review): a two-argument and a three-argument call both appear in this span.
doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), RegexReplaceMode)
doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), REGEXP_LIMITED_CHARS_REPLACE,
RegexReplaceMode)
}

test("AST fuzz test - regexp_find - full unicode input") {
  // None is passed as the data character set (per the test name, full unicode input);
  // pattern characters are still REGEXP_LIMITED_CHARS_FIND
  val dataChars: Option[String] = None
  doAstFuzzTest(dataChars, REGEXP_LIMITED_CHARS_FIND, RegexFindMode)
}

test("AST fuzz test - regexp_replace - full unicode input") {
  // None is passed as the data character set (per the test name, full unicode input);
  // pattern characters are still REGEXP_LIMITED_CHARS_REPLACE
  val dataChars: Option[String] = None
  doAstFuzzTest(dataChars, REGEXP_LIMITED_CHARS_REPLACE, RegexReplaceMode)
}

test("string split - optimized") {
Expand Down Expand Up @@ -582,7 +589,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {

test("string split fuzz") {
val (data, patterns) = generateDataAndPatterns(Some(REGEXP_LIMITED_CHARS_REPLACE),
RegexSplitMode)
REGEXP_LIMITED_CHARS_REPLACE, RegexSplitMode)
for (limit <- Seq(-2, -1, 2, 5)) {
doStringSplitTest(patterns, data, limit)
}
Expand Down Expand Up @@ -643,26 +650,29 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}
}

// Obtains fuzz data and patterns from generateDataAndPatterns, then checks them with
// assertCpuGpuMatchesRegexpReplace when mode is RegexReplaceMode and with
// assertCpuGpuMatchesRegexpFind for any other mode.
// NOTE(review): two signatures appear in this span; the notes below describe the
// three-parameter form.
//  - validDataChars: forwarded to generateDataAndPatterns; presumably the characters
//    allowed in generated input strings, with None meaning unrestricted -- TODO confirm
//  - validPatternChars: forwarded to generateDataAndPatterns; presumably the characters
//    allowed in generated patterns -- TODO confirm
//  - mode: forwarded to generateDataAndPatterns and used to pick the assertion helper
private def doAstFuzzTest(validChars: Option[String], mode: RegexMode) {
val (data, patterns) = generateDataAndPatterns(validChars, mode)
private def doAstFuzzTest(validDataChars: Option[String], validPatternChars: String,
mode: RegexMode) {
val (data, patterns) = generateDataAndPatterns(validDataChars, validPatternChars, mode)
if (mode == RegexReplaceMode) {
assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data)
} else {
assertCpuGpuMatchesRegexpFind(patterns.toSeq, data)
}
}

private def generateDataAndPatterns(validChars: Option[String], mode: RegexMode)
: (Seq[String], Set[String]) = {
val r = new EnhancedRandom(new Random(seed = 0L),
FuzzerOptions(validChars, maxStringLen = 12))
private def generateDataAndPatterns(
validDataChars: Option[String],
validPatternChars: String,
mode: RegexMode): (Seq[String], Set[String]) = {

val fuzzer = new FuzzRegExp(REGEXP_LIMITED_CHARS_FIND)
val dataGen = new EnhancedRandom(new Random(seed = 0L),
FuzzerOptions(validDataChars, maxStringLen = 12))

val data = Range(0, 1000)
.map(_ => r.nextString())
.map(_ => dataGen.nextString())

// generate patterns that are valid on both CPU and GPU
val fuzzer = new FuzzRegExp(validPatternChars)
val patterns = HashSet[String]()
while (patterns.size < 5000) {
val pattern = fuzzer.generate(0).toRegexString
Expand All @@ -676,7 +686,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}

private def assertCpuGpuMatchesRegexpFind(javaPatterns: Seq[String], input: Seq[String]) = {
for (javaPattern <- javaPatterns) {
for ((javaPattern, patternIndex) <- javaPatterns.zipWithIndex) {
val cpu = cpuContains(javaPattern, input)
val (cudfPattern, _) =
new CudfRegexTranspiler(RegexFindMode).transpile(javaPattern, None)
Expand All @@ -688,7 +698,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}
for (i <- input.indices) {
if (cpu(i) != gpu(i)) {
fail(s"javaPattern=${toReadableString(javaPattern)}, " +
fail(s"javaPattern[$patternIndex]=${toReadableString(javaPattern)}, " +
s"cudfPattern=${toReadableString(cudfPattern)}, " +
s"input='${toReadableString(input(i))}', " +
s"cpu=${cpu(i)}, gpu=${gpu(i)}")
Expand All @@ -700,7 +710,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
private def assertCpuGpuMatchesRegexpReplace(
javaPatterns: Seq[String],
input: Seq[String]) = {
for (javaPattern <- javaPatterns) {
for ((javaPattern, patternIndex) <- javaPatterns.zipWithIndex) {
val cpu = cpuReplace(javaPattern, input)
val (cudfPattern, replaceString) =
(new CudfRegexTranspiler(RegexReplaceMode)).transpile(javaPattern,
Expand All @@ -715,7 +725,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}
for (i <- input.indices) {
if (cpu(i) != gpu(i)) {
fail(s"javaPattern=${toReadableString(javaPattern)}, " +
fail(s"javaPattern[$patternIndex]=${toReadableString(javaPattern)}, " +
s"cudfPattern=${toReadableString(cudfPattern)}, " +
s"input='${toReadableString(input(i))}', " +
s"cpu=${toReadableString(cpu(i))}, " +
Expand Down