diff --git a/docs/compatibility.md b/docs/compatibility.md index 337633b6d99..0f215973a0a 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -574,15 +574,14 @@ The following Apache Spark regular expression functions and expressions are supp - `string_split` - `str_to_map` -Regular expression evaluation on the GPU is enabled by default. Execution will fall back to the CPU for -regular expressions that are not yet supported on the GPU. However, there are some edge cases that will -still execute on the GPU and produce different results to the CPU. To disable regular expressions on the GPU, -set `spark.rapids.sql.regexp.enabled=false`. +Regular expression evaluation on the GPU is enabled by default when the UTF-8 character set is used +by the current locale. Execution will fall back to the CPU for regular expressions that are not yet +supported on the GPU, and in environments where the locale does not use UTF-8. However, there are +some edge cases that will still execute on the GPU and produce different results to the CPU. To +disable regular expressions on the GPU, set `spark.rapids.sql.regexp.enabled=false`. These are the known edge cases where running on the GPU will produce different results to the CPU: -- Using regular expressions with Unicode data can produce incorrect results if the system `LANG` is not set - to `en_US.UTF-8` ([#5549](https://github.com/NVIDIA/spark-rapids/issues/5549)) - Regular expressions that contain an end of line anchor '$' or end of string anchor '\Z' or '\z' immediately next to a newline or a repetition that produces zero or more results ([#5610](https://github.com/NVIDIA/spark-rapids/pull/5610)) @@ -596,7 +595,6 @@ The following regular expression patterns are not yet supported on the GPU and w or more results - Line anchor `$` and string anchors `\z` and `\Z` are not supported in patterns containing `\W` or `\D` - Line and string anchors are not supported by `string_split` and `str_to_map` -- Word and non-word boundaries, `\b` and `\B` - Lazy quantifiers, such as `a*?` - Possessive quantifiers, such as `a*+` - Character classes that use union, intersection, or subtraction semantics, such as `[a-d[m-p]]`, `[a-z&&[def]]`, @@ -604,6 +602,12 @@ The following regular expression patterns are not yet supported on the GPU and w - Empty groups: `()` - `regexp_replace` does not support back-references +The following regular expression patterns can produce different results on the GPU compared to +the CPU: + +- Word and non-word boundaries, `\b` and `\B` + + Work is ongoing to increase the range of regular expressions that can run on the GPU.
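As a reference for the behavior described above, here is a minimal PySpark sketch (illustrative only, not part of this change). It assumes an existing `spark` session with the plugin loaded; the locale check mirrors the one used by the integration tests in this change.

```python
import locale

# The integration tests use this same check to decide whether GPU regular
# expression support is expected to be active in the current environment.
if locale.nl_langinfo(locale.CODESET) != 'UTF-8':
    print('Locale does not use UTF-8; regular expressions will fall back to the CPU')

# Regular expressions on the GPU can also be disabled explicitly:
spark.conf.set('spark.rapids.sql.regexp.enabled', 'false')
```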
## Timestamps diff --git a/integration_tests/pytest.ini b/integration_tests/pytest.ini index 60f8894160d..f4d9793c5c0 100644 --- a/integration_tests/pytest.ini +++ b/integration_tests/pytest.ini @@ -30,5 +30,6 @@ markers = nightly_host_mem_consuming_case: case in nightly_resource_consuming_test that consume much more host memory than normal cases fuzz_test: Mark fuzz tests iceberg: Mark a test that requires Iceberg has been configured, skipping if tests are not configured for Iceberg + regexp: Mark a test that exercises regular expressions on the GPU (runs only when the current locale uses UTF-8) filterwarnings = ignore:.*pytest.mark.order.*:_pytest.warning_types.PytestUnknownMarkWarning diff --git a/integration_tests/src/main/python/regexp_no_unicode_test.py b/integration_tests/src/main/python/regexp_no_unicode_test.py new file mode 100644 index 00000000000..230c06f4d3f --- /dev/null +++ b/integration_tests/src/main/python/regexp_no_unicode_test.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import locale +import pytest + +from asserts import assert_gpu_fallback_collect +from data_gen import * +from marks import * +from pyspark.sql.types import * + +if locale.nl_langinfo(locale.CODESET) == 'UTF-8': +    pytestmark = pytest.mark.skip(reason="Current locale uses UTF-8, fallback will not occur") + +_regexp_conf = { 'spark.rapids.sql.regexp.enabled': 'true' } + +def mk_str_gen(pattern): +    return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}') + +@allow_non_gpu('ProjectExec', 'RLike') +def test_rlike_no_unicode_fallback(): +    gen = mk_str_gen('[abcd]{1,3}') +    assert_gpu_fallback_collect( +        lambda spark: unary_op_df(spark, gen).selectExpr( +            'a rlike "ab"'), +        'RLike', +        conf=_regexp_conf) + +@allow_non_gpu('ProjectExec', 'RegExpReplace') +def test_re_replace_no_unicode_fallback(): +    gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') +    assert_gpu_fallback_collect( +        lambda spark: unary_op_df(spark, gen).selectExpr( +            'REGEXP_REPLACE(a, "TEST", "PROD")'), +        'RegExpReplace', +        conf=_regexp_conf) + +@allow_non_gpu('ProjectExec', 'StringSplit') +def test_split_re_no_unicode_fallback(): +    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ +            .with_special_case('boo:and:foo') +    assert_gpu_fallback_collect( +        lambda spark : unary_op_df(spark, data_gen).selectExpr( +            'split(a, "[o]", 2)'), +        'StringSplit', +        conf=_regexp_conf) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py new file mode 100644 index 00000000000..d8988840b32 --- /dev/null +++ b/integration_tests/src/main/python/regexp_test.py @@ -0,0 +1,764 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import locale +import pytest + +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, \ +    assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_error, \ +    assert_gpu_sql_fallback_collect +from data_gen import * +from marks import * +from pyspark.sql.types import * +from spark_session import is_before_spark_320 + +if locale.nl_langinfo(locale.CODESET) != 'UTF-8': +    pytestmark = [pytest.mark.regexp, pytest.mark.skip(reason="Current locale doesn't support UTF-8, regexp support is disabled")] +else: +    pytestmark = pytest.mark.regexp + +_regexp_conf = { 'spark.rapids.sql.regexp.enabled': 'true' } + +def mk_str_gen(pattern): +    return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}') + +def test_split_re_negative_limit(): +    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ +            .with_special_case('boo:and:foo') +    assert_gpu_and_cpu_are_equal_collect( +        lambda spark : unary_op_df(spark, data_gen).selectExpr( +            'split(a, "[:]", -1)', +            'split(a, "[o:]", -1)', +            'split(a, "[^:]", -1)', +            'split(a, "[^o]", -1)', +            'split(a, "[o]{1,2}", -1)', +            'split(a, "[bf]", -1)', +            'split(a, "[o]", -2)'), +            conf=_regexp_conf) + +def test_split_re_zero_limit(): +    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ +            .with_special_case('boo:and:foo') +    assert_gpu_and_cpu_are_equal_collect( +        lambda spark : unary_op_df(spark, data_gen).selectExpr( +            'split(a, "[:]", 0)', +            'split(a, "[o:]", 0)', +            'split(a, "[^:]", 0)', +            'split(a, "[^o]", 0)', +            'split(a, "[o]{1,2}", 0)', +            'split(a, "[bf]", 0)', +            'split(a, "[o]", 0)'), +        conf=_regexp_conf) + +def test_split_re_one_limit(): +    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ +            .with_special_case('boo:and:foo') +    assert_gpu_and_cpu_are_equal_collect( +        lambda spark : unary_op_df(spark, data_gen).selectExpr( +            'split(a, "[:]", 1)', +            'split(a, "[o:]", 1)', +            'split(a, "[^:]", 1)', +            'split(a, "[^o]", 1)', +            'split(a, "[o]{1,2}", 1)', +            'split(a, "[bf]", 1)', +            'split(a, "[o]", 1)'), +        conf=_regexp_conf) + +def test_split_re_positive_limit(): +    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ +            .with_special_case('boo:and:foo') +    assert_gpu_and_cpu_are_equal_collect( +        lambda spark : unary_op_df(spark, data_gen).selectExpr( +            'split(a, "[:]", 2)', +            'split(a, "[o:]", 5)', +            'split(a, "[^:]", 2)', +            'split(a, "[^o]", 55)', +            'split(a, "[o]{1,2}", 999)', +            'split(a, "[bf]", 2)', +            'split(a, "[o]", 5)'), +            conf=_regexp_conf) + +def test_split_re_no_limit(): +    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ +            .with_special_case('boo:and:foo') +    assert_gpu_and_cpu_are_equal_collect( +        lambda spark : unary_op_df(spark, data_gen).selectExpr( +            'split(a, "[:]")', +            'split(a, "[o:]")', +            'split(a, "[^:]")', +            'split(a, "[^o]")', +            'split(a, "[o]{1,2}")', +            'split(a, "[bf]")', +            'split(a, "[o]")'), +            conf=_regexp_conf) + +def test_split_optimized_no_re(): +    data_gen = mk_str_gen('([bf]o{0,2}[.?+\\^$|{}]{1,2}){1,7}') \ +        .with_special_case('boo.and.foo') \ +        .with_special_case('boo?and?foo') \ +        .with_special_case('boo+and+foo') \ +        .with_special_case('boo^and^foo') \ +        .with_special_case('boo$and$foo') \ +
.with_special_case('boo|and|foo') \ + .with_special_case('boo{and}foo') \ + .with_special_case('boo$|and$|foo') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "\\\\.")', + 'split(a, "\\\\?")', + 'split(a, "\\\\+")', + 'split(a, "\\\\^")', + 'split(a, "\\\\$")', + 'split(a, "\\\\|")', + 'split(a, "\\\\{")', + 'split(a, "\\\\}")', + 'split(a, "\\\\$\\\\|")'), + conf=_regexp_conf) + +def test_split_optimized_no_re_combined(): + data_gen = mk_str_gen('([bf]o{0,2}[AZ.?+\\^$|{}]{1,2}){1,7}') \ + .with_special_case('booA.ZandA.Zfoo') \ + .with_special_case('booA?ZandA?Zfoo') \ + .with_special_case('booA+ZandA+Zfoo') \ + .with_special_case('booA^ZandA^Zfoo') \ + .with_special_case('booA$ZandA$Zfoo') \ + .with_special_case('booA|ZandA|Zfoo') \ + .with_special_case('boo{Zand}Zfoo') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "A\\\\.Z")', + 'split(a, "A\\\\?Z")', + 'split(a, "A\\\\+Z")', + 'split(a, "A\\\\^Z")', + 'split(a, "A\\\\$Z")', + 'split(a, "A\\\\|Z")', + 'split(a, "\\\\{Z")', + 'split(a, "\\\\}Z")'), + conf=_regexp_conf) + +def test_split_regexp_disabled_no_fallback(): + conf = { 'spark.rapids.sql.regexp.enabled': 'false' } + data_gen = mk_str_gen('([bf]o{0,2}[.?+\\^$|&_]{1,2}){1,7}') \ + .with_special_case('boo.and.foo') \ + .with_special_case('boo?and?foo') \ + .with_special_case('boo+and+foo') \ + .with_special_case('boo^and^foo') \ + .with_special_case('boo$and$foo') \ + .with_special_case('boo|and|foo') \ + .with_special_case('boo&and&foo') \ + .with_special_case('boo_and_foo') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "\\\\.")', + 'split(a, "\\\\?")', + 'split(a, "\\\\+")', + 'split(a, "\\\\^")', + 'split(a, "\\\\$")', + 'split(a, "\\\\|")', + 'split(a, "&")', + 'split(a, "_")', + ), conf + ) + +@allow_non_gpu('ProjectExec', 'StringSplit') +def test_split_regexp_disabled_fallback(): + conf = { 'spark.rapids.sql.regexp.enabled': 'false' } + data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ + .with_special_case('boo:and:foo') + assert_gpu_sql_fallback_collect( + lambda spark : unary_op_df(spark, data_gen), + 'StringSplit', + 'string_split_table', + 'select ' + + 'split(a, "[:]", 2), ' + + 'split(a, "[o:]", 5), ' + + 'split(a, "[^:]", 2), ' + + 'split(a, "[^o]", 55), ' + + 'split(a, "[o]{1,2}", 999), ' + + 'split(a, "[bf]", 2), ' + + 'split(a, "[o]", 5) from string_split_table', + conf) + + +def test_re_replace(): + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "TEST", "PROD")', + 'REGEXP_REPLACE(a, "^TEST", "PROD")', + 'REGEXP_REPLACE(a, "^TEST\\z", "PROD")', + 'REGEXP_REPLACE(a, "TEST\\z", "PROD")', + 'REGEXP_REPLACE(a, "\\zTEST", "PROD")', + 'REGEXP_REPLACE(a, "TEST\\z", "PROD")', + 'REGEXP_REPLACE(a, "\\^TEST\\z", "PROD")', + 'REGEXP_REPLACE(a, "\\^TEST\\z", "PROD")', + 'REGEXP_REPLACE(a, "TEST", "")', + 'REGEXP_REPLACE(a, "TEST", "%^[]\ud720")', + 'REGEXP_REPLACE(a, "TEST", NULL)'), + conf=_regexp_conf) + +# We have shims to support empty strings for zero-repetition patterns +# See https://github.com/NVIDIA/spark-rapids/issues/5456 +def test_re_replace_repetition(): + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "[E]+", "PROD")', + 'REGEXP_REPLACE(a, "[A]+", "PROD")', + 
'REGEXP_REPLACE(a, "A{0,}", "PROD")', + 'REGEXP_REPLACE(a, "T?E?", "PROD")', + 'REGEXP_REPLACE(a, "A*", "PROD")', + 'REGEXP_REPLACE(a, "A{0,5}", "PROD")', + 'REGEXP_REPLACE(a, "(A*)", "PROD")', + 'REGEXP_REPLACE(a, "(((A*)))", "PROD")', + 'REGEXP_REPLACE(a, "((A*)E?)", "PROD")', + 'REGEXP_REPLACE(a, "[A-Z]?", "PROD")' + ), + conf=_regexp_conf) + + +@allow_non_gpu('ProjectExec', 'RegExpReplace') +def test_re_replace_issue_5492(): + # https://github.com/NVIDIA/spark-rapids/issues/5492 + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "[^\\\\sa-zA-Z0-9]", "x")'), + 'RegExpReplace', + conf=_regexp_conf) + +def test_re_replace_backrefs(): + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}TEST') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "(TEST)", "$1")', + 'REGEXP_REPLACE(a, "(TEST)", "[$0]")', + 'REGEXP_REPLACE(a, "(TEST)", "[\\1]")', + 'REGEXP_REPLACE(a, "(T)[a-z]+(T)", "[$2][$1][$0]")', + 'REGEXP_REPLACE(a, "([0-9]+)(T)[a-z]+(T)", "[$3][$2][$1]")', + 'REGEXP_REPLACE(a, "(.)([0-9]+TEST)", "$0 $1 $2")', + 'REGEXP_REPLACE(a, "(TESTT)", "\\0 \\1")' # no match + ), + conf=_regexp_conf) + +def test_re_replace_anchors(): + gen = mk_str_gen('.{0,2}TEST[\ud720 A]{0,5}TEST[\r\n\u0085\u2028\u2029]?') \ + .with_special_case("TEST") \ + .with_special_case("TEST\n") \ + .with_special_case("TEST\r\n") \ + .with_special_case("TEST\r") + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "TEST$", "")', + 'REGEXP_REPLACE(a, "TEST$", "PROD")', + 'REGEXP_REPLACE(a, "\ud720[A-Z]+$", "PROD")', + 'REGEXP_REPLACE(a, "(\ud720[A-Z]+)$", "PROD")', + 'REGEXP_REPLACE(a, "(TEST)$", "$1")', + 'REGEXP_REPLACE(a, "^(TEST)$", "$1")', + 'REGEXP_REPLACE(a, "\\\\ATEST\\\\Z", "PROD")', + 'REGEXP_REPLACE(a, "\\\\ATEST$", "PROD")', + 'REGEXP_REPLACE(a, "^TEST\\\\Z", "PROD")', + 'REGEXP_REPLACE(a, "TEST\\\\Z", "PROD")', + 'REGEXP_REPLACE(a, "TEST\\\\z", "PROD")', + 'REGEXP_REPLACE(a, "\\\\zTEST", "PROD")', + 'REGEXP_REPLACE(a, "^TEST$", "PROD")', + 'REGEXP_REPLACE(a, "^TEST\\\\z", "PROD")', + 'REGEXP_REPLACE(a, "TEST\\\\z", "PROD")', + ), + conf=_regexp_conf) + +# For GPU runs, cuDF will check the range and throw exception if index is out of range +def test_re_replace_backrefs_idx_out_of_bounds(): + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') + assert_gpu_and_cpu_error(lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "(T)(E)(S)(T)", "[$5]")').collect(), + conf=_regexp_conf, + error_message='') + +def test_re_replace_backrefs_escaped(): + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "(TEST)", "[\\\\$0]")', + 'REGEXP_REPLACE(a, "(TEST)", "[\\\\$1]")'), + conf=_regexp_conf) + +def test_re_replace_escaped(): + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "[A-Z]+", "\\\\A\\A\\\\t\\\\r\\\\n\\t\\r\\n")'), + conf=_regexp_conf) + +def test_re_replace_null(): + gen = mk_str_gen('[\u0000 ]{0,2}TE[\u0000 ]{0,2}ST[\u0000 ]{0,2}')\ + .with_special_case("\u0000")\ + .with_special_case("\u0000\u0000") + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "\u0000", "")', + 'REGEXP_REPLACE(a, "\000", "")', + 
'REGEXP_REPLACE(a, "\00", "")', + 'REGEXP_REPLACE(a, "\x00", "")', + 'REGEXP_REPLACE(a, "\0", "")', + 'REGEXP_REPLACE(a, "\u0000", "NULL")', + 'REGEXP_REPLACE(a, "\000", "NULL")', + 'REGEXP_REPLACE(a, "\00", "NULL")', + 'REGEXP_REPLACE(a, "\x00", "NULL")', + 'REGEXP_REPLACE(a, "\0", "NULL")', + 'REGEXP_REPLACE(a, "TE\u0000ST", "PROD")', + 'REGEXP_REPLACE(a, "TE\u0000\u0000ST", "PROD")', + 'REGEXP_REPLACE(a, "[\x00TEST]", "PROD")', + 'REGEXP_REPLACE(a, "[TE\00ST]", "PROD")', + 'REGEXP_REPLACE(a, "[\u0000-z]", "PROD")'), + conf=_regexp_conf) + +def test_regexp_replace(): + gen = mk_str_gen('[abcd]{0,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_replace(a, "a", "A")', + 'regexp_replace(a, "[^xyz]", "A")', + 'regexp_replace(a, "([^x])|([^y])", "A")', + 'regexp_replace(a, "(?:aa)+", "A")', + 'regexp_replace(a, "a|b|c", "A")'), + conf=_regexp_conf) + +@pytest.mark.skipif(is_before_spark_320(), reason='regexp is synonym for RLike starting in Spark 3.2.0') +def test_regexp(): + gen = mk_str_gen('[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp(a, "a{2}")', + 'regexp(a, "a{1,3}")', + 'regexp(a, "a{1,}")', + 'regexp(a, "a[bc]d")'), + conf=_regexp_conf) + +@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0') +def test_regexp_like(): + gen = mk_str_gen('[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_like(a, "a{2}")', + 'regexp_like(a, "a{1,3}")', + 'regexp_like(a, "a{1,}")', + 'regexp_like(a, "a[bc]d")'), + conf=_regexp_conf) + +def test_regexp_replace_character_set_negated(): + gen = mk_str_gen('[abcd]{0,3}[\r\n]{0,2}[abcd]{0,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_replace(a, "([^a])|([^b])", "1")', + 'regexp_replace(a, "[^a]", "1")', + 'regexp_replace(a, "([^a]|[\r\n])", "1")', + 'regexp_replace(a, "[^a\r\n]", "1")', + 'regexp_replace(a, "[^a\r]", "1")', + 'regexp_replace(a, "[^a\n]", "1")', + 'regexp_replace(a, "[^\r\n]", "1")', + 'regexp_replace(a, "[^\r]", "1")', + 'regexp_replace(a, "[^\n]", "1")'), + conf=_regexp_conf) + +def test_regexp_extract(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "([0-9]+)", 1)', + 'regexp_extract(a, "([0-9])([abcd]+)", 1)', + 'regexp_extract(a, "([0-9])([abcd]+)", 2)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 1)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 2)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 3)'), + conf=_regexp_conf) + +def test_regexp_extract_no_match(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 0)', + 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 1)', + 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 2)', + 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 3)'), + conf=_regexp_conf) + +# if we determine that the index is out of range we fall back to CPU and let +# Spark take care of the error handling +@allow_non_gpu('ProjectExec', 'RegExpExtract') +def test_regexp_extract_idx_negative(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, 
gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", -1)').collect(), + error_message = "The specified group index cannot be less than zero", + conf=_regexp_conf) + +# if we determine that the index is out of range we fall back to CPU and let +# Spark take care of the error handling +@allow_non_gpu('ProjectExec', 'RegExpExtract') +def test_regexp_extract_idx_out_of_bounds(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 4)').collect(), + error_message = "Regex group count is 3, but the specified group index is 4", + conf=_regexp_conf) + +def test_regexp_extract_multiline(): + gen = mk_str_gen('[abcd]{2}[\r\n]{0,2}[0-9]{2}[\r\n]{0,2}[abcd]{2}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([\r\n]*)", 2)'), + conf=_regexp_conf) + +def test_regexp_extract_multiline_negated_character_class(): + gen = mk_str_gen('[abcd]{2}[\r\n]{0,2}[0-9]{2}[\r\n]{0,2}[abcd]{2}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([^a-z]*)([a-d]*)\\z", 2)'), + conf=_regexp_conf) + +def test_regexp_extract_idx_0(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "([0-9]+)[abcd]([abcd]+)", 0)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 0)', + 'regexp_extract(a, "^([a-d]*)[0-9]*([a-d]*)\\z", 0)'), + conf=_regexp_conf) + +def test_word_boundaries(): + gen = StringGen('([abc]{1,3}[\r\n\t \f]{0,2}[123]){1,5}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'rlike(a, "\\\\b")', + 'rlike(a, "\\\\B")', + 'rlike(a, "\\\\b\\\\B")', + 'regexp_extract(a, "([a-d]+)\\\\b([e-h]+)", 1)', + 'regexp_extract(a, "([a-d]+)\\\\B", 1)', + 'regexp_replace(a, "\\\\b", "#")', + 'regexp_replace(a, "\\\\B", "#")', + ), + conf=_regexp_conf) + +def test_character_classes(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}[ \n\t\r]{0,2}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'rlike(a, "[abcd]")', + 'rlike(a, "[^\n\r]")', + 'rlike(a, "[\n-\\]")', + 'rlike(a, "[+--]")', + 'regexp_extract(a, "[123]", 0)', + 'regexp_replace(a, "[\\\\0101-\\\\0132]", "@")', + 'regexp_replace(a, "[\\\\x41-\\\\x5a]", "@")', + ), + conf=_regexp_conf) + +def test_regexp_hexadecimal_digits(): + gen = mk_str_gen( + '[abcd]\\\\x00\\\\x7f\\\\x80\\\\xff\\\\x{10ffff}\\\\x{00eeee}[\\\\xa0-\\\\xb0][abcd]') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'rlike(a, "\\\\x7f")', + 'rlike(a, "\\\\x80")', + 'rlike(a, "[\\\\xa0-\\\\xf0]")', + 'rlike(a, "\\\\x{00eeee}")', + 'regexp_extract(a, "([a-d]+)\\\\xa0([a-d]+)", 1)', + 'regexp_extract(a, "([a-d]+)[\\\\xa0\nabcd]([a-d]+)", 1)', + 'regexp_replace(a, "\\\\xff", "@")', + 'regexp_replace(a, "[\\\\xa0-\\\\xb0]", "@")', + 'regexp_replace(a, "\\\\x{10ffff}", "@")', + ), + conf=_regexp_conf) + +def test_regexp_whitespace(): + gen = mk_str_gen('\u001e[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f[0-9]{0,10}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'rlike(a, "\\\\s")', + 'rlike(a, "\\\\s{3}")', + 'rlike(a, "[abcd]+\\\\s+[0-9]+")', + 'rlike(a, "\\\\S{3}")', + 'rlike(a, "[abcd]+\\\\s+\\\\S{2,3}")', + 
'regexp_extract(a, "([a-d]+)(\\\\s[0-9]+)([a-d]+)", 2)', + 'regexp_extract(a, "([a-d]+)(\\\\S+)([0-9]+)", 2)', + 'regexp_extract(a, "([a-d]+)(\\\\S+)([0-9]+)", 3)', + 'regexp_replace(a, "(\\\\s+)", "@")', + 'regexp_replace(a, "(\\\\S+)", "#")', + ), + conf=_regexp_conf) + +def test_regexp_horizontal_vertical_whitespace(): + gen = mk_str_gen( + '''\xA0\u1680\u180e[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f[0-9]{0,10} + [\u2001-\u200a]{1,3}\u202f\u205f\u3000\x85\u2028\u2029 + ''') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'rlike(a, "\\\\h{2}")', + 'rlike(a, "\\\\v{3}")', + 'rlike(a, "[abcd]+\\\\h+[0-9]+")', + 'rlike(a, "[abcd]+\\\\v+[0-9]+")', + 'rlike(a, "\\\\H")', + 'rlike(a, "\\\\V")', + 'rlike(a, "[abcd]+\\\\h+\\\\V{2,3}")', + 'regexp_extract(a, "([a-d]+)([0-9]+\\\\v)([a-d]+)", 2)', + 'regexp_extract(a, "([a-d]+)(\\\\H+)([0-9]+)", 2)', + 'regexp_extract(a, "([a-d]+)(\\\\V+)([0-9]+)", 3)', + 'regexp_replace(a, "(\\\\v+)", "@")', + 'regexp_replace(a, "(\\\\H+)", "#")', + ), + conf=_regexp_conf) + +def test_regexp_linebreak(): + gen = mk_str_gen( + '[abc]{1,3}\u000D\u000A[def]{1,3}[\u000A\u000B\u000C\u000D\u0085\u2028\u2029]{0,5}[123]') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'rlike(a, "\\\\R")', + 'regexp_extract(a, "([a-d]+)(\\\\R)([a-d]+)", 1)', + 'regexp_replace(a, "\\\\R", "")', + ), + conf=_regexp_conf) + +def test_regexp_octal_digits(): + gen = mk_str_gen('[abcd]\u0000\u0041\u007f\u0080\u00ff[\\\\xa0-\\\\xb0][abcd]') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'rlike(a, "\\\\0177")', + 'rlike(a, "\\\\0200")', + 'rlike(a, "\\\\0101")', + 'rlike(a, "[\\\\0240-\\\\0377]")', + 'regexp_extract(a, "([a-d]+)\\\\0240([a-d]+)", 1)', + 'regexp_extract(a, "([a-d]+)[\\\\0141-\\\\0172]([a-d]+)", 0)', + 'regexp_replace(a, "\\\\0377", "")', + 'regexp_replace(a, "\\\\0260", "")', + ), + conf=_regexp_conf) + +def test_regexp_replace_digit(): + gen = mk_str_gen('[a-z]{0,2}[0-9]{0,2}') \ + .with_special_case('䤫畍킱곂⬡❽ࢅ獰᳌蛫青') \ + .with_special_case('a\n2\r\n3') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_replace(a, "\\\\d", "x")', + 'regexp_replace(a, "\\\\D", "x")', + 'regexp_replace(a, "[0-9]", "x")', + 'regexp_replace(a, "[^0-9]", "x")', + ), + conf=_regexp_conf) + +def test_regexp_replace_word(): + gen = mk_str_gen('[a-z]{0,2}[_]{0,1}[0-9]{0,2}') \ + .with_special_case('䤫畍킱곂⬡❽ࢅ獰᳌蛫青') \ + .with_special_case('a\n2\r\n3') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_replace(a, "\\\\w", "x")', + 'regexp_replace(a, "\\\\W", "x")', + 'regexp_replace(a, "[a-zA-Z_0-9]", "x")', + 'regexp_replace(a, "[^a-zA-Z_0-9]", "x")', + ), + conf=_regexp_conf) + +def test_predefined_character_classes(): + gen = mk_str_gen('[a-zA-Z]{0,2}[\r\n!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~]{0,2}[0-9]{0,2}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_replace(a, "\\\\p{Lower}", "x")', + 'regexp_replace(a, "\\\\p{Upper}", "x")', + 'regexp_replace(a, "\\\\p{ASCII}", "x")', + 'regexp_replace(a, "\\\\p{Alpha}", "x")', + 'regexp_replace(a, "\\\\p{Digit}", "x")', + 'regexp_replace(a, "\\\\p{Alnum}", "x")', + 'regexp_replace(a, "\\\\p{Punct}", "x")', + 'regexp_replace(a, "\\\\p{Graph}", "x")', + 'regexp_replace(a, "\\\\p{Print}", "x")', + 'regexp_replace(a, "\\\\p{Blank}", "x")', + 'regexp_replace(a, "\\\\p{Cntrl}", 
"x")', + 'regexp_replace(a, "\\\\p{XDigit}", "x")', + 'regexp_replace(a, "\\\\p{Space}", "x")', + 'regexp_replace(a, "\\\\P{Lower}", "x")', + 'regexp_replace(a, "\\\\P{Upper}", "x")', + 'regexp_replace(a, "\\\\P{ASCII}", "x")', + 'regexp_replace(a, "\\\\P{Alpha}", "x")', + 'regexp_replace(a, "\\\\P{Digit}", "x")', + 'regexp_replace(a, "\\\\P{Alnum}", "x")', + 'regexp_replace(a, "\\\\P{Punct}", "x")', + 'regexp_replace(a, "\\\\P{Graph}", "x")', + 'regexp_replace(a, "\\\\P{Print}", "x")', + 'regexp_replace(a, "\\\\P{Blank}", "x")', + 'regexp_replace(a, "\\\\P{Cntrl}", "x")', + 'regexp_replace(a, "\\\\P{XDigit}", "x")', + 'regexp_replace(a, "\\\\P{Space}", "x")', + ), + conf=_regexp_conf) + +def test_rlike(): + gen = mk_str_gen('[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "a{2}"', + 'a rlike "a{1,3}"', + 'a rlike "a{1,}"', + 'a rlike "a[bc]d"'), + conf=_regexp_conf) + +def test_rlike_embedded_null(): + gen = mk_str_gen('[abcd]{1,3}')\ + .with_special_case('\u0000aaa') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "a{2}"', + 'a rlike "a{1,3}"', + 'a rlike "a{1,}"', + 'a rlike "a[bc]d"'), + conf=_regexp_conf) + +def test_rlike_null_pattern(): + gen = mk_str_gen('[abcd]{1,3}') + # Spark optimizes out `RLIKE NULL` in this test + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike NULL')) + +@allow_non_gpu('ProjectExec', 'RLike') +def test_rlike_fallback_empty_group(): + gen = mk_str_gen('[abcd]{1,3}') + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "a()?"'), + 'RLike', + conf=_regexp_conf) + +def test_rlike_escape(): + gen = mk_str_gen('[ab]{0,2}[\\-\\+]{0,2}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "a[\\\\-]"'), + conf=_regexp_conf) + +def test_rlike_multi_line(): + gen = mk_str_gen('[abc]\n[def]') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "^a"', + 'a rlike "^d"', + 'a rlike "c\\z"', + 'a rlike "e\\z"'), + conf=_regexp_conf) + +def test_rlike_missing_escape(): + gen = mk_str_gen('a[\\-\\+]') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "a[-]"', + 'a rlike "a[+-]"', + 'a rlike "a[a-b-]"'), + conf=_regexp_conf) + +@allow_non_gpu('ProjectExec', 'RLike') +def test_rlike_fallback_possessive_quantifier(): + gen = mk_str_gen('(\u20ac|\\w){0,3}a[|b*.$\r\n]{0,2}c\\w{0,3}') + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "a*+"'), + 'RLike', + conf=_regexp_conf) + +def test_regexp_extract_all_idx_zero(): + gen = mk_str_gen('[abcd]{0,3}[0-9]{0,3}-[0-9]{0,3}[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract_all(a, "([a-d]+).*([0-9])", 0)', + 'regexp_extract_all(a, "(a)(b)", 0)', + 'regexp_extract_all(a, "([a-z0-9]([abcd]))", 0)', + 'regexp_extract_all(a, "(\\\\d+)-(\\\\d+)", 0)', + ), + conf=_regexp_conf) + +def test_regexp_extract_all_idx_positive(): + gen = mk_str_gen('[abcd]{0,3}[0-9]{0,3}-[0-9]{0,3}[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract_all(a, "([a-d]+).*([0-9])", 1)', + 'regexp_extract_all(a, "(a)(b)", 2)', + 'regexp_extract_all(a, "([a-z0-9]((([abcd](\\\\d?)))))", 3)', + 'regexp_extract_all(a, 
"(\\\\d+)-(\\\\d+)", 2)', + ), + conf=_regexp_conf) + +@allow_non_gpu('ProjectExec', 'RegExpExtractAll') +def test_regexp_extract_all_idx_negative(): + gen = mk_str_gen('[abcd]{0,3}') + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract_all(a, "(a)", -1)' + ).collect(), + error_message="The specified group index cannot be less than zero", + conf=_regexp_conf) + +@allow_non_gpu('ProjectExec', 'RegExpExtractAll') +def test_regexp_extract_all_idx_out_of_bounds(): + gen = mk_str_gen('[abcd]{0,3}') + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract_all(a, "([a-d]+).*([0-9])", 3)' + ).collect(), + error_message="Regex group count is 2, but the specified group index is 3", + conf=_regexp_conf) + +def test_rlike_unicode_support(): + gen = mk_str_gen('a[\ud720\ud800\ud900]')\ + .with_special_case('a䤫畍킱곂⬡❽ࢅ獰᳌蛫青') + + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "a*"', + 'a rlike "a\ud720"', + 'a rlike "a\ud720.+$"'), + conf=_regexp_conf) + +def test_regexp_replace_unicode_support(): + gen = mk_str_gen('TEST[85\ud720\ud800\ud900]')\ + .with_special_case('TEST䤫畍킱곂⬡❽ࢅ獰᳌蛫青') + + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "TEST\ud720", "PROD")', + 'REGEXP_REPLACE(a, "TEST\\\\b", "PROD")', + 'REGEXP_REPLACE(a, "TEST\\\\B", "PROD")', + 'REGEXP_REPLACE(a, "TEST䤫", "PROD")', + 'REGEXP_REPLACE(a, "TEST[䤫]", "PROD")', + 'REGEXP_REPLACE(a, "TEST.*\\\\d", "PROD")', + 'REGEXP_REPLACE(a, "TEST.+$", "PROD")', + ), + conf=_regexp_conf) + +def test_regexp_split_unicode_support(): + data_gen = mk_str_gen('([bf]o{0,2}青){1,7}') \ + .with_special_case('boo青and青foo') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "[青]", -1)', + 'split(a, "[o青]", -1)', + 'split(a, "[^青]", -1)', + 'split(a, "[^o]", -1)', + 'split(a, "[o]{1,2}", -1)', + 'split(a, "[bf]", -1)', + 'split(a, "[o]", -2)'), + conf=_regexp_conf) \ No newline at end of file diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index ed4aa92640d..e421e11cafa 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -73,164 +73,6 @@ def test_split_positive_limit(): 'split(a, "C", 3)', 'split(a, "_", 999)')) -def test_split_re_negative_limit(): - data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ - .with_special_case('boo:and:foo') - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr( - 'split(a, "[:]", -1)', - 'split(a, "[o:]", -1)', - 'split(a, "[^:]", -1)', - 'split(a, "[^o]", -1)', - 'split(a, "[o]{1,2}", -1)', - 'split(a, "[bf]", -1)', - 'split(a, "[o]", -2)'), - conf=_regexp_conf) - -def test_split_re_zero_limit(): - data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ - .with_special_case('boo:and:foo') - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr( - 'split(a, "[:]", 0)', - 'split(a, "[o:]", 0)', - 'split(a, "[^:]", 0)', - 'split(a, "[^o]", 0)', - 'split(a, "[o]{1,2}", 0)', - 'split(a, "[bf]", 0)', - 'split(a, "[o]", 0)'), - conf=_regexp_conf) - -def test_split_re_one_limit(): - data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ - .with_special_case('boo:and:foo') - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr( - 'split(a, "[:]", 1)', - 
'split(a, "[o:]", 1)', - 'split(a, "[^:]", 1)', - 'split(a, "[^o]", 1)', - 'split(a, "[o]{1,2}", 1)', - 'split(a, "[bf]", 1)', - 'split(a, "[o]", 1)'), - conf=_regexp_conf) - -def test_split_re_positive_limit(): - data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ - .with_special_case('boo:and:foo') - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr( - 'split(a, "[:]", 2)', - 'split(a, "[o:]", 5)', - 'split(a, "[^:]", 2)', - 'split(a, "[^o]", 55)', - 'split(a, "[o]{1,2}", 999)', - 'split(a, "[bf]", 2)', - 'split(a, "[o]", 5)'), - conf=_regexp_conf) - -def test_split_re_no_limit(): - data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ - .with_special_case('boo:and:foo') - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr( - 'split(a, "[:]")', - 'split(a, "[o:]")', - 'split(a, "[^:]")', - 'split(a, "[^o]")', - 'split(a, "[o]{1,2}")', - 'split(a, "[bf]")', - 'split(a, "[o]")'), - conf=_regexp_conf) - -def test_split_optimized_no_re(): - data_gen = mk_str_gen('([bf]o{0,2}[.?+\\^$|{}]{1,2}){1,7}') \ - .with_special_case('boo.and.foo') \ - .with_special_case('boo?and?foo') \ - .with_special_case('boo+and+foo') \ - .with_special_case('boo^and^foo') \ - .with_special_case('boo$and$foo') \ - .with_special_case('boo|and|foo') \ - .with_special_case('boo{and}foo') \ - .with_special_case('boo$|and$|foo') - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr( - 'split(a, "\\\\.")', - 'split(a, "\\\\?")', - 'split(a, "\\\\+")', - 'split(a, "\\\\^")', - 'split(a, "\\\\$")', - 'split(a, "\\\\|")', - 'split(a, "\\\\{")', - 'split(a, "\\\\}")', - 'split(a, "\\\\$\\\\|")'), - conf=_regexp_conf) - -def test_split_optimized_no_re_combined(): - data_gen = mk_str_gen('([bf]o{0,2}[AZ.?+\\^$|{}]{1,2}){1,7}') \ - .with_special_case('booA.ZandA.Zfoo') \ - .with_special_case('booA?ZandA?Zfoo') \ - .with_special_case('booA+ZandA+Zfoo') \ - .with_special_case('booA^ZandA^Zfoo') \ - .with_special_case('booA$ZandA$Zfoo') \ - .with_special_case('booA|ZandA|Zfoo') \ - .with_special_case('boo{Zand}Zfoo') - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr( - 'split(a, "A\\\\.Z")', - 'split(a, "A\\\\?Z")', - 'split(a, "A\\\\+Z")', - 'split(a, "A\\\\^Z")', - 'split(a, "A\\\\$Z")', - 'split(a, "A\\\\|Z")', - 'split(a, "\\\\{Z")', - 'split(a, "\\\\}Z")'), - conf=_regexp_conf) - -def test_split_regexp_disabled_no_fallback(): - conf = { 'spark.rapids.sql.regexp.enabled': 'false' } - data_gen = mk_str_gen('([bf]o{0,2}[.?+\\^$|&_]{1,2}){1,7}') \ - .with_special_case('boo.and.foo') \ - .with_special_case('boo?and?foo') \ - .with_special_case('boo+and+foo') \ - .with_special_case('boo^and^foo') \ - .with_special_case('boo$and$foo') \ - .with_special_case('boo|and|foo') \ - .with_special_case('boo&and&foo') \ - .with_special_case('boo_and_foo') - assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr( - 'split(a, "\\\\.")', - 'split(a, "\\\\?")', - 'split(a, "\\\\+")', - 'split(a, "\\\\^")', - 'split(a, "\\\\$")', - 'split(a, "\\\\|")', - 'split(a, "&")', - 'split(a, "_")', - ), conf - ) - -@allow_non_gpu('ProjectExec', 'StringSplit') -def test_split_regexp_disabled_fallback(): - conf = { 'spark.rapids.sql.regexp.enabled': 'false' } - data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ - .with_special_case('boo:and:foo') - assert_gpu_sql_fallback_collect( - lambda spark : unary_op_df(spark, data_gen), - 'StringSplit', - 
'string_split_table', - 'select ' + - 'split(a, "[:]", 2), ' + - 'split(a, "[o:]", 5), ' + - 'split(a, "[^:]", 2), ' + - 'split(a, "[^o]", 55), ' + - 'split(a, "[o]{1,2}", 999), ' + - 'split(a, "[bf]", 2), ' + - 'split(a, "[o]", 5) from string_split_table', - conf) - - @pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'), (mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'), (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn) @@ -528,139 +370,6 @@ def test_replace(): 'REPLACE(a, NULL, "PROD")', 'REPLACE(a, "T", "")')) -def test_re_replace(): - gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "TEST", "PROD")', - 'REGEXP_REPLACE(a, "^TEST", "PROD")', - 'REGEXP_REPLACE(a, "^TEST\\z", "PROD")', - 'REGEXP_REPLACE(a, "TEST\\z", "PROD")', - 'REGEXP_REPLACE(a, "\\zTEST", "PROD")', - 'REGEXP_REPLACE(a, "TEST\\z", "PROD")', - 'REGEXP_REPLACE(a, "\\^TEST\\z", "PROD")', - 'REGEXP_REPLACE(a, "\\^TEST\\z", "PROD")', - 'REGEXP_REPLACE(a, "TEST", "")', - 'REGEXP_REPLACE(a, "TEST", "%^[]\ud720")', - 'REGEXP_REPLACE(a, "TEST", NULL)'), - conf=_regexp_conf) - -# We have shims to support empty strings for zero-repetition patterns -# See https://github.com/NVIDIA/spark-rapids/issues/5456 -def test_re_replace_repetition(): - gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "[E]+", "PROD")', - 'REGEXP_REPLACE(a, "[A]+", "PROD")', - 'REGEXP_REPLACE(a, "A{0,}", "PROD")', - 'REGEXP_REPLACE(a, "T?E?", "PROD")', - 'REGEXP_REPLACE(a, "A*", "PROD")', - 'REGEXP_REPLACE(a, "A{0,5}", "PROD")', - 'REGEXP_REPLACE(a, "(A*)", "PROD")', - 'REGEXP_REPLACE(a, "(((A*)))", "PROD")', - 'REGEXP_REPLACE(a, "((A*)E?)", "PROD")', - 'REGEXP_REPLACE(a, "[A-Z]?", "PROD")' - ), - conf=_regexp_conf) - - -@allow_non_gpu('ProjectExec', 'RegExpReplace') -def test_re_replace_issue_5492(): - # https://github.com/NVIDIA/spark-rapids/issues/5492 - gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') - assert_gpu_fallback_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "[^\\\\sa-zA-Z0-9]", "x")'), - 'RegExpReplace', - conf=_regexp_conf) - -def test_re_replace_backrefs(): - gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}TEST') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "(TEST)", "$1")', - 'REGEXP_REPLACE(a, "(TEST)", "[$0]")', - 'REGEXP_REPLACE(a, "(TEST)", "[\\1]")', - 'REGEXP_REPLACE(a, "(T)[a-z]+(T)", "[$2][$1][$0]")', - 'REGEXP_REPLACE(a, "([0-9]+)(T)[a-z]+(T)", "[$3][$2][$1]")', - 'REGEXP_REPLACE(a, "(.)([0-9]+TEST)", "$0 $1 $2")', - 'REGEXP_REPLACE(a, "(TESTT)", "\\0 \\1")' # no match - ), - conf=_regexp_conf) - -def test_re_replace_anchors(): - gen = mk_str_gen('.{0,2}TEST[\ud720 A]{0,5}TEST[\r\n\u0085\u2028\u2029]?') \ - .with_special_case("TEST") \ - .with_special_case("TEST\n") \ - .with_special_case("TEST\r\n") \ - .with_special_case("TEST\r") - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "TEST$", "")', - 'REGEXP_REPLACE(a, "TEST$", "PROD")', - 'REGEXP_REPLACE(a, "\ud720[A-Z]+$", "PROD")', - 'REGEXP_REPLACE(a, "(\ud720[A-Z]+)$", "PROD")', - 'REGEXP_REPLACE(a, "(TEST)$", "$1")', - 'REGEXP_REPLACE(a, "^(TEST)$", "$1")', - 'REGEXP_REPLACE(a, "\\\\ATEST\\\\Z", "PROD")', - 'REGEXP_REPLACE(a, "\\\\ATEST$", "PROD")', - 'REGEXP_REPLACE(a, 
"^TEST\\\\Z", "PROD")', - 'REGEXP_REPLACE(a, "TEST\\\\Z", "PROD")', - 'REGEXP_REPLACE(a, "TEST\\\\z", "PROD")', - 'REGEXP_REPLACE(a, "\\\\zTEST", "PROD")', - 'REGEXP_REPLACE(a, "^TEST$", "PROD")', - 'REGEXP_REPLACE(a, "^TEST\\\\z", "PROD")', - 'REGEXP_REPLACE(a, "TEST\\\\z", "PROD")', - ), - conf=_regexp_conf) - -# For GPU runs, cuDF will check the range and throw exception if index is out of range -def test_re_replace_backrefs_idx_out_of_bounds(): - gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') - assert_gpu_and_cpu_error(lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "(T)(E)(S)(T)", "[$5]")').collect(), - conf=_regexp_conf, - error_message='') - -def test_re_replace_backrefs_escaped(): - gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "(TEST)", "[\\\\$0]")', - 'REGEXP_REPLACE(a, "(TEST)", "[\\\\$1]")'), - conf=_regexp_conf) - -def test_re_replace_escaped(): - gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "[A-Z]+", "\\\\A\\A\\\\t\\\\r\\\\n\\t\\r\\n")'), - conf=_regexp_conf) - -def test_re_replace_null(): - gen = mk_str_gen('[\u0000 ]{0,2}TE[\u0000 ]{0,2}ST[\u0000 ]{0,2}')\ - .with_special_case("\u0000")\ - .with_special_case("\u0000\u0000") - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "\u0000", "")', - 'REGEXP_REPLACE(a, "\000", "")', - 'REGEXP_REPLACE(a, "\00", "")', - 'REGEXP_REPLACE(a, "\x00", "")', - 'REGEXP_REPLACE(a, "\0", "")', - 'REGEXP_REPLACE(a, "\u0000", "NULL")', - 'REGEXP_REPLACE(a, "\000", "NULL")', - 'REGEXP_REPLACE(a, "\00", "NULL")', - 'REGEXP_REPLACE(a, "\x00", "NULL")', - 'REGEXP_REPLACE(a, "\0", "NULL")', - 'REGEXP_REPLACE(a, "TE\u0000ST", "PROD")', - 'REGEXP_REPLACE(a, "TE\u0000\u0000ST", "PROD")', - 'REGEXP_REPLACE(a, "[\x00TEST]", "PROD")', - 'REGEXP_REPLACE(a, "[TE\00ST]", "PROD")', - 'REGEXP_REPLACE(a, "[\u0000-z]", "PROD")'), - conf=_regexp_conf) - def test_length(): gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') assert_gpu_and_cpu_are_equal_collect( @@ -778,400 +487,3 @@ def test_like_complex_escape(): 'a like "_oo"'), conf={'spark.sql.parser.escapedStringLiterals': 'true'}) -def test_regexp_replace(): - gen = mk_str_gen('[abcd]{0,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_replace(a, "a", "A")', - 'regexp_replace(a, "[^xyz]", "A")', - 'regexp_replace(a, "([^x])|([^y])", "A")', - 'regexp_replace(a, "(?:aa)+", "A")', - 'regexp_replace(a, "a|b|c", "A")'), - conf=_regexp_conf) - -@pytest.mark.skipif(is_before_spark_320(), reason='regexp is synonym for RLike starting in Spark 3.2.0') -def test_regexp(): - gen = mk_str_gen('[abcd]{1,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp(a, "a{2}")', - 'regexp(a, "a{1,3}")', - 'regexp(a, "a{1,}")', - 'regexp(a, "a[bc]d")'), - conf=_regexp_conf) - -@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0') -def test_regexp_like(): - gen = mk_str_gen('[abcd]{1,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_like(a, "a{2}")', - 'regexp_like(a, "a{1,3}")', - 'regexp_like(a, "a{1,}")', - 'regexp_like(a, "a[bc]d")'), - conf=_regexp_conf) - -def test_regexp_replace_character_set_negated(): - gen 
= mk_str_gen('[abcd]{0,3}[\r\n]{0,2}[abcd]{0,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_replace(a, "([^a])|([^b])", "1")', - 'regexp_replace(a, "[^a]", "1")', - 'regexp_replace(a, "([^a]|[\r\n])", "1")', - 'regexp_replace(a, "[^a\r\n]", "1")', - 'regexp_replace(a, "[^a\r]", "1")', - 'regexp_replace(a, "[^a\n]", "1")', - 'regexp_replace(a, "[^\r\n]", "1")', - 'regexp_replace(a, "[^\r]", "1")', - 'regexp_replace(a, "[^\n]", "1")'), - conf=_regexp_conf) - -def test_regexp_extract(): - gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract(a, "([0-9]+)", 1)', - 'regexp_extract(a, "([0-9])([abcd]+)", 1)', - 'regexp_extract(a, "([0-9])([abcd]+)", 2)', - 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 1)', - 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 2)', - 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 3)'), - conf=_regexp_conf) - -def test_regexp_extract_no_match(): - gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 0)', - 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 1)', - 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 2)', - 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 3)'), - conf=_regexp_conf) - -# if we determine that the index is out of range we fall back to CPU and let -# Spark take care of the error handling -@allow_non_gpu('ProjectExec', 'RegExpExtract') -def test_regexp_extract_idx_negative(): - gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') - assert_gpu_and_cpu_error( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", -1)').collect(), - error_message = "The specified group index cannot be less than zero", - conf=_regexp_conf) - -# if we determine that the index is out of range we fall back to CPU and let -# Spark take care of the error handling -@allow_non_gpu('ProjectExec', 'RegExpExtract') -def test_regexp_extract_idx_out_of_bounds(): - gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') - assert_gpu_and_cpu_error( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 4)').collect(), - error_message = "Regex group count is 3, but the specified group index is 4", - conf=_regexp_conf) - -def test_regexp_extract_multiline(): - gen = mk_str_gen('[abcd]{2}[\r\n]{0,2}[0-9]{2}[\r\n]{0,2}[abcd]{2}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract(a, "^([a-d]*)([\r\n]*)", 2)'), - conf=_regexp_conf) - -def test_regexp_extract_multiline_negated_character_class(): - gen = mk_str_gen('[abcd]{2}[\r\n]{0,2}[0-9]{2}[\r\n]{0,2}[abcd]{2}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract(a, "^([a-d]*)([^a-z]*)([a-d]*)\\z", 2)'), - conf=_regexp_conf) - -def test_regexp_extract_idx_0(): - gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract(a, "([0-9]+)[abcd]([abcd]+)", 0)', - 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 0)', - 'regexp_extract(a, "^([a-d]*)[0-9]*([a-d]*)\\z", 0)'), - conf=_regexp_conf) - -def test_word_boundaries(): - gen = StringGen('([abc]{1,3}[\r\n\t \f]{0,2}[123]){1,5}') - 
assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'rlike(a, "\\\\b")', - 'rlike(a, "\\\\B")', - 'rlike(a, "\\\\b\\\\B")', - 'regexp_extract(a, "([a-d]+)\\\\b([e-h]+)", 1)', - 'regexp_extract(a, "([a-d]+)\\\\B", 1)', - 'regexp_replace(a, "\\\\b", "#")', - 'regexp_replace(a, "\\\\B", "#")', - ), - conf=_regexp_conf) - -def test_character_classes(): - gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}[ \n\t\r]{0,2}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'rlike(a, "[abcd]")', - 'rlike(a, "[^\n\r]")', - 'rlike(a, "[\n-\\]")', - 'rlike(a, "[+--]")', - 'regexp_extract(a, "[123]", 0)', - 'regexp_replace(a, "[\\\\0101-\\\\0132]", "@")', - 'regexp_replace(a, "[\\\\x41-\\\\x5a]", "@")', - ), - conf=_regexp_conf) - -def test_regexp_hexadecimal_digits(): - gen = mk_str_gen( - '[abcd]\\\\x00\\\\x7f\\\\x80\\\\xff\\\\x{10ffff}\\\\x{00eeee}[\\\\xa0-\\\\xb0][abcd]') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'rlike(a, "\\\\x7f")', - 'rlike(a, "\\\\x80")', - 'rlike(a, "[\\\\xa0-\\\\xf0]")', - 'rlike(a, "\\\\x{00eeee}")', - 'regexp_extract(a, "([a-d]+)\\\\xa0([a-d]+)", 1)', - 'regexp_extract(a, "([a-d]+)[\\\\xa0\nabcd]([a-d]+)", 1)', - 'regexp_replace(a, "\\\\xff", "@")', - 'regexp_replace(a, "[\\\\xa0-\\\\xb0]", "@")', - 'regexp_replace(a, "\\\\x{10ffff}", "@")', - ), - conf=_regexp_conf) - -def test_regexp_whitespace(): - gen = mk_str_gen('\u001e[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f[0-9]{0,10}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'rlike(a, "\\\\s")', - 'rlike(a, "\\\\s{3}")', - 'rlike(a, "[abcd]+\\\\s+[0-9]+")', - 'rlike(a, "\\\\S{3}")', - 'rlike(a, "[abcd]+\\\\s+\\\\S{2,3}")', - 'regexp_extract(a, "([a-d]+)(\\\\s[0-9]+)([a-d]+)", 2)', - 'regexp_extract(a, "([a-d]+)(\\\\S+)([0-9]+)", 2)', - 'regexp_extract(a, "([a-d]+)(\\\\S+)([0-9]+)", 3)', - 'regexp_replace(a, "(\\\\s+)", "@")', - 'regexp_replace(a, "(\\\\S+)", "#")', - ), - conf=_regexp_conf) - -def test_regexp_horizontal_vertical_whitespace(): - gen = mk_str_gen( - '''\xA0\u1680\u180e[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f[0-9]{0,10} - [\u2001-\u200a]{1,3}\u202f\u205f\u3000\x85\u2028\u2029 - ''') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'rlike(a, "\\\\h{2}")', - 'rlike(a, "\\\\v{3}")', - 'rlike(a, "[abcd]+\\\\h+[0-9]+")', - 'rlike(a, "[abcd]+\\\\v+[0-9]+")', - 'rlike(a, "\\\\H")', - 'rlike(a, "\\\\V")', - 'rlike(a, "[abcd]+\\\\h+\\\\V{2,3}")', - 'regexp_extract(a, "([a-d]+)([0-9]+\\\\v)([a-d]+)", 2)', - 'regexp_extract(a, "([a-d]+)(\\\\H+)([0-9]+)", 2)', - 'regexp_extract(a, "([a-d]+)(\\\\V+)([0-9]+)", 3)', - 'regexp_replace(a, "(\\\\v+)", "@")', - 'regexp_replace(a, "(\\\\H+)", "#")', - ), - conf=_regexp_conf) - -def test_regexp_linebreak(): - gen = mk_str_gen( - '[abc]{1,3}\u000D\u000A[def]{1,3}[\u000A\u000B\u000C\u000D\u0085\u2028\u2029]{0,5}[123]') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'rlike(a, "\\\\R")', - 'regexp_extract(a, "([a-d]+)(\\\\R)([a-d]+)", 1)', - 'regexp_replace(a, "\\\\R", "")', - ), - conf=_regexp_conf) - -def test_regexp_octal_digits(): - gen = mk_str_gen('[abcd]\u0000\u0041\u007f\u0080\u00ff[\\\\xa0-\\\\xb0][abcd]') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'rlike(a, "\\\\0177")', - 'rlike(a, "\\\\0200")', - 'rlike(a, "\\\\0101")', - 'rlike(a, 
"[\\\\0240-\\\\0377]")', - 'regexp_extract(a, "([a-d]+)\\\\0240([a-d]+)", 1)', - 'regexp_extract(a, "([a-d]+)[\\\\0141-\\\\0172]([a-d]+)", 0)', - 'regexp_replace(a, "\\\\0377", "")', - 'regexp_replace(a, "\\\\0260", "")', - ), - conf=_regexp_conf) - -def test_regexp_replace_digit(): - gen = mk_str_gen('[a-z]{0,2}[0-9]{0,2}') \ - .with_special_case('䤫畍킱곂⬡❽ࢅ獰᳌蛫青') \ - .with_special_case('a\n2\r\n3') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_replace(a, "\\\\d", "x")', - 'regexp_replace(a, "\\\\D", "x")', - 'regexp_replace(a, "[0-9]", "x")', - 'regexp_replace(a, "[^0-9]", "x")', - ), - conf=_regexp_conf) - -def test_regexp_replace_word(): - gen = mk_str_gen('[a-z]{0,2}[_]{0,1}[0-9]{0,2}') \ - .with_special_case('䤫畍킱곂⬡❽ࢅ獰᳌蛫青') \ - .with_special_case('a\n2\r\n3') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_replace(a, "\\\\w", "x")', - 'regexp_replace(a, "\\\\W", "x")', - 'regexp_replace(a, "[a-zA-Z_0-9]", "x")', - 'regexp_replace(a, "[^a-zA-Z_0-9]", "x")', - ), - conf=_regexp_conf) - -def test_predefined_character_classes(): - gen = mk_str_gen('[a-zA-Z]{0,2}[\r\n!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~]{0,2}[0-9]{0,2}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_replace(a, "\\\\p{Lower}", "x")', - 'regexp_replace(a, "\\\\p{Upper}", "x")', - 'regexp_replace(a, "\\\\p{ASCII}", "x")', - 'regexp_replace(a, "\\\\p{Alpha}", "x")', - 'regexp_replace(a, "\\\\p{Digit}", "x")', - 'regexp_replace(a, "\\\\p{Alnum}", "x")', - 'regexp_replace(a, "\\\\p{Punct}", "x")', - 'regexp_replace(a, "\\\\p{Graph}", "x")', - 'regexp_replace(a, "\\\\p{Print}", "x")', - 'regexp_replace(a, "\\\\p{Blank}", "x")', - 'regexp_replace(a, "\\\\p{Cntrl}", "x")', - 'regexp_replace(a, "\\\\p{XDigit}", "x")', - 'regexp_replace(a, "\\\\p{Space}", "x")', - 'regexp_replace(a, "\\\\P{Lower}", "x")', - 'regexp_replace(a, "\\\\P{Upper}", "x")', - 'regexp_replace(a, "\\\\P{ASCII}", "x")', - 'regexp_replace(a, "\\\\P{Alpha}", "x")', - 'regexp_replace(a, "\\\\P{Digit}", "x")', - 'regexp_replace(a, "\\\\P{Alnum}", "x")', - 'regexp_replace(a, "\\\\P{Punct}", "x")', - 'regexp_replace(a, "\\\\P{Graph}", "x")', - 'regexp_replace(a, "\\\\P{Print}", "x")', - 'regexp_replace(a, "\\\\P{Blank}", "x")', - 'regexp_replace(a, "\\\\P{Cntrl}", "x")', - 'regexp_replace(a, "\\\\P{XDigit}", "x")', - 'regexp_replace(a, "\\\\P{Space}", "x")', - ), - conf=_regexp_conf) - -def test_rlike(): - gen = mk_str_gen('[abcd]{1,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike "a{2}"', - 'a rlike "a{1,3}"', - 'a rlike "a{1,}"', - 'a rlike "a[bc]d"'), - conf=_regexp_conf) - -def test_rlike_embedded_null(): - gen = mk_str_gen('[abcd]{1,3}')\ - .with_special_case('\u0000aaa') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike "a{2}"', - 'a rlike "a{1,3}"', - 'a rlike "a{1,}"', - 'a rlike "a[bc]d"'), - conf=_regexp_conf) - -def test_rlike_null_pattern(): - gen = mk_str_gen('[abcd]{1,3}') - # Spark optimizes out `RLIKE NULL` in this test - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike NULL')) - -@allow_non_gpu('ProjectExec', 'RLike') -def test_rlike_fallback_empty_group(): - gen = mk_str_gen('[abcd]{1,3}') - assert_gpu_fallback_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike "a()?"'), - 'RLike', - conf=_regexp_conf) - -def 
test_rlike_escape(): - gen = mk_str_gen('[ab]{0,2}[\\-\\+]{0,2}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike "a[\\\\-]"'), - conf=_regexp_conf) - -def test_rlike_multi_line(): - gen = mk_str_gen('[abc]\n[def]') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike "^a"', - 'a rlike "^d"', - 'a rlike "c\\z"', - 'a rlike "e\\z"'), - conf=_regexp_conf) - -def test_rlike_missing_escape(): - gen = mk_str_gen('a[\\-\\+]') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike "a[-]"', - 'a rlike "a[+-]"', - 'a rlike "a[a-b-]"'), - conf=_regexp_conf) - -@allow_non_gpu('ProjectExec', 'RLike') -def test_rlike_fallback_possessive_quantifier(): - gen = mk_str_gen('(\u20ac|\\w){0,3}a[|b*.$\r\n]{0,2}c\\w{0,3}') - assert_gpu_fallback_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike "a*+"'), - 'RLike', - conf=_regexp_conf) - -def test_regexp_extract_all_idx_zero(): - gen = mk_str_gen('[abcd]{0,3}[0-9]{0,3}-[0-9]{0,3}[abcd]{1,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract_all(a, "([a-d]+).*([0-9])", 0)', - 'regexp_extract_all(a, "(a)(b)", 0)', - 'regexp_extract_all(a, "([a-z0-9]([abcd]))", 0)', - 'regexp_extract_all(a, "(\\\\d+)-(\\\\d+)", 0)', - ), - conf=_regexp_conf) - -def test_regexp_extract_all_idx_positive(): - gen = mk_str_gen('[abcd]{0,3}[0-9]{0,3}-[0-9]{0,3}[abcd]{1,3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract_all(a, "([a-d]+).*([0-9])", 1)', - 'regexp_extract_all(a, "(a)(b)", 2)', - 'regexp_extract_all(a, "([a-z0-9]((([abcd](\\\\d?)))))", 3)', - 'regexp_extract_all(a, "(\\\\d+)-(\\\\d+)", 2)', - ), - conf=_regexp_conf) - -@allow_non_gpu('ProjectExec', 'RegExpExtractAll') -def test_regexp_extract_all_idx_negative(): - gen = mk_str_gen('[abcd]{0,3}') - assert_gpu_and_cpu_error( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract_all(a, "(a)", -1)' - ).collect(), - error_message="The specified group index cannot be less than zero", - conf=_regexp_conf) - -@allow_non_gpu('ProjectExec', 'RegExpExtractAll') -def test_regexp_extract_all_idx_out_of_bounds(): - gen = mk_str_gen('[abcd]{0,3}') - assert_gpu_and_cpu_error( - lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp_extract_all(a, "([a-d]+).*([0-9])", 3)' - ).collect(), - error_message="Regex group count is 2, but the specified group index is 3", - conf=_regexp_conf) diff --git a/jenkins/spark-premerge-build.sh b/jenkins/spark-premerge-build.sh index b16b25b44a9..40ae460e1b9 100755 --- a/jenkins/spark-premerge-build.sh +++ b/jenkins/spark-premerge-build.sh @@ -47,6 +47,8 @@ mvn_verify() { # don't skip tests env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=320 clean install -Drat.skip=true -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -Dpytest.TEST_TAGS='' -pl '!tools' + # enable UTF-8 for regular expression tests + env -u SPARK_HOME LC_ALL="en_US.UTF-8" mvn $MVN_URM_MIRROR -Dbuildver=320 test -Drat.skip=true -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -Dpytest.TEST_TAGS='' -pl '!tools' -DwildcardSuites=com.nvidia.spark.rapids.ConditionalsSuite,com.nvidia.spark.rapids.RegularExpressionSuite,com.nvidia.spark.rapids.RegularExpressionTranspilerSuite env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=321 
diff --git a/jenkins/spark-premerge-build.sh b/jenkins/spark-premerge-build.sh
index b16b25b44a9..40ae460e1b9 100755
--- a/jenkins/spark-premerge-build.sh
+++ b/jenkins/spark-premerge-build.sh
@@ -47,6 +47,8 @@ mvn_verify() {
     # don't skip tests
     env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=320 clean install -Drat.skip=true -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -Dpytest.TEST_TAGS='' -pl '!tools'
+    # enable UTF-8 for regular expression tests
+    env -u SPARK_HOME LC_ALL="en_US.UTF-8" mvn $MVN_URM_MIRROR -Dbuildver=320 test -Drat.skip=true -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -Dpytest.TEST_TAGS='' -pl '!tools' -DwildcardSuites=com.nvidia.spark.rapids.ConditionalsSuite,com.nvidia.spark.rapids.RegularExpressionSuite,com.nvidia.spark.rapids.RegularExpressionTranspilerSuite
     env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=321 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am
     [[ $BUILD_MAINTENANCE_VERSION_SNAPSHOTS == "true" ]] && env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=322 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am
     env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=330 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am
@@ -124,6 +126,8 @@ ci_2() {
     TEST='not conditionals_test and not window_function_test and not struct_test and not time_window_test' \
         ./integration_tests/run_pyspark_from_build.sh
     INCLUDE_SPARK_AVRO_JAR=true TEST='avro_test.py' ./integration_tests/run_pyspark_from_build.sh
+    # export 'LC_ALL' to set a UTF-8 locale so regular expressions are enabled on the GPU
+    LC_ALL="en_US.UTF-8" TEST="regexp_test.py" ./integration_tests/run_pyspark_from_build.sh
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
index bf17fee61fd..ed8feccda3a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
@@ -954,7 +954,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
         // NOTE: this applies to when using *standard* mode. In multiline mode, all these
         // conditions will change. Currently Spark does not use multiline mode.
         previous match {
-          case Some(RegexChar('$')) =>
+          case Some(RegexChar('$')) | Some(RegexEscaped('Z')) =>
             // repeating the line anchor in cuDF (for example b$$) causes matches to fail, but in
             // Java, it's treated as a single (b$ and b$$ are synonymous), so we create
             // an empty RegexAST that outputs to empty string
@@ -1060,7 +1060,20 @@ class CudfRegexTranspiler(mode: RegexMode) {
       case 'b' | 'B' if mode == RegexSplitMode =>
         // see https://github.com/NVIDIA/spark-rapids/issues/5478
        throw new RegexUnsupportedException(
-          "Word boundaries are not supported in split mode", regex.position)
+          "Word boundaries are not supported in split mode", regex.position)
+      case 'b' | 'B' =>
+        previous match {
+          case Some(RegexEscaped(ch)) if "DWSHV".contains(ch) =>
+            throw new RegexUnsupportedException(
+              "Word boundaries around \\D, \\S, \\W, \\H, or \\V are not supported",
+              regex.position)
+          case Some(RegexCharacterClass(negated, _)) if negated =>
+            throw new RegexUnsupportedException(
+              "Word boundaries around negated character classes are not supported",
+              regex.position)
+          case _ =>
+            RegexEscaped(ch)
+        }
       case 'A' if mode == RegexSplitMode =>
         throw new RegexUnsupportedException(
           "String anchor \\A is not supported in split mode", regex.position)
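(Editor's note, not part of the diff: the transpiler change above rejects `\b`/`\B` next to `\D`, `\S`, `\W`, `\H`, `\V`, or a negated character class, since cuDF and Java can disagree about matches at exactly those positions; see the issue 5882 references further down. As a rough illustration of how position-sensitive the CPU semantics are, a sketch in plain Java-regex terms, not plugin code:)

```scala
import java.util.regex.Pattern

object WordBoundaryDemo {
  def main(args: Array[String]): Unit = {
    // \b is a zero-width assertion between a word char ([a-zA-Z0-9_]) and a
    // non-word char. Pairing it with \W makes the boundary position subtle:
    // here \W consumes the space, then \b must hold between ' ' and 'b'.
    val m = Pattern.compile("\\W\\b").matcher("a b")
    while (m.find()) {
      println(s"match '${m.group}' at ${m.start}..${m.end}") // match ' ' at 1..2
    }
  }
}
```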
meta.willNotWorkOnGpu(s"regular expression support is disabled. " + s"Set ${RapidsConf.ENABLE_REGEXP}=true to enable it") } + + Charset.defaultCharset().name() match { + case "UTF-8" => + // supported + case _ => + meta.willNotWorkOnGpu(s"regular expression support is disabled because the GPU only " + + "supports the UTF-8 charset when using regular expressions") + } } /** diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ConditionalsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ConditionalsSuite.scala index 38f8b2c5b57..21315366111 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ConditionalsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ConditionalsSuite.scala @@ -15,6 +15,8 @@ */ package com.nvidia.spark.rapids +import java.nio.charset.Charset + import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions.expr @@ -26,6 +28,7 @@ class ConditionalsSuite extends SparkQueryCompareTestSuite { .set(RapidsConf.ENABLE_REGEXP.key, "true") testSparkResultsAreEqual("CASE WHEN test all branches", testData, conf) { df => + assume(isUnicodeEnabled()) df.withColumn("test", expr( "CASE " + "WHEN a RLIKE '^[0-9]{1,3}\\z' THEN CAST(a AS INT) " + @@ -34,6 +37,7 @@ class ConditionalsSuite extends SparkQueryCompareTestSuite { } testSparkResultsAreEqual("CASE WHEN first branch always true", testData2, conf) { df => + assume(isUnicodeEnabled()) df.withColumn("test", expr( "CASE " + "WHEN a RLIKE '^[0-9]{1,3}\\z' THEN CAST(a AS INT) " + @@ -42,6 +46,7 @@ class ConditionalsSuite extends SparkQueryCompareTestSuite { } testSparkResultsAreEqual("CASE WHEN second branch always true", testData2, conf) { df => + assume(isUnicodeEnabled()) df.withColumn("test", expr( "CASE " + "WHEN a RLIKE '^[0-9]{4,6}\\z' THEN CAST(a AS INT) " + @@ -50,6 +55,7 @@ class ConditionalsSuite extends SparkQueryCompareTestSuite { } testSparkResultsAreEqual("CASE WHEN else condition always true", testData2, conf) { df => + assume(isUnicodeEnabled()) df.withColumn("test", expr( "CASE " + "WHEN a RLIKE '^[0-9]{4,6}\\z' THEN CAST(a AS INT) " + @@ -58,6 +64,7 @@ class ConditionalsSuite extends SparkQueryCompareTestSuite { } testSparkResultsAreEqual("CASE WHEN first or second branch is true", testData3, conf) { df => + assume(isUnicodeEnabled()) df.withColumn("test", expr( "CASE " + "WHEN a RLIKE '^[0-9]{1,3}\\z' THEN CAST(a AS INT) " + @@ -77,6 +84,7 @@ class ConditionalsSuite extends SparkQueryCompareTestSuite { testSparkResultsAreEqual("CASE WHEN with null predicate values after first branch", testData3, conf) { df => + assume(isUnicodeEnabled()) df.withColumn("test", expr( "CASE " + "WHEN char_length(a) IS NULL THEN -999 " + @@ -114,4 +122,7 @@ class ConditionalsSuite extends SparkQueryCompareTestSuite { ).toDF("a").repartition(2) } + private def isUnicodeEnabled(): Boolean = { + Charset.defaultCharset().name() == "UTF-8" + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala index 077a3b8fbec..7478e112216 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala @@ -15,6 +15,8 @@ */ package com.nvidia.spark.rapids +import java.nio.charset.Charset + import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.SparkConf @@ -60,33 +62,39 @@ class RegularExpressionSuite extends 
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala
index 077a3b8fbec..7478e112216 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala
@@ -15,6 +15,8 @@
  */
 package com.nvidia.spark.rapids
 
+import java.nio.charset.Charset
+
 import com.nvidia.spark.rapids.shims.SparkShimImpl
 
 import org.apache.spark.SparkConf
@@ -60,33 +62,39 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite {
   }
 
   testSparkResultsAreEqual("String regexp_replace regex 1",
-    nullableStringsFromCsv, conf = conf) {
-    frame => frame.selectExpr("regexp_replace(strings,'.*','D')")
+    nullableStringsFromCsv, conf = conf) { frame =>
+    assume(isUnicodeEnabled())
+    frame.selectExpr("regexp_replace(strings,'.*','D')")
   }
 
   testSparkResultsAreEqual("String regexp_replace regex 2",
-    nullableStringsFromCsv, conf = conf) {
-    frame => frame.selectExpr("regexp_replace(strings,'[a-z]+','D')")
+    nullableStringsFromCsv, conf = conf) { frame =>
+    assume(isUnicodeEnabled())
+    frame.selectExpr("regexp_replace(strings,'[a-z]+','D')")
   }
 
   testSparkResultsAreEqual("String regexp_replace regex 3",
-    nullableStringsFromCsv, conf = conf) {
-    frame => frame.selectExpr("regexp_replace(strings,'foo$','D')")
+    nullableStringsFromCsv, conf = conf) { frame =>
+    assume(isUnicodeEnabled())
+    frame.selectExpr("regexp_replace(strings,'foo$','D')")
   }
 
   testSparkResultsAreEqual("String regexp_replace regex 4",
-    nullableStringsFromCsv, conf = conf) {
-    frame => frame.selectExpr("regexp_replace(strings,'^foo','D')")
+    nullableStringsFromCsv, conf = conf) { frame =>
+    assume(isUnicodeEnabled())
+    frame.selectExpr("regexp_replace(strings,'^foo','D')")
   }
 
   testSparkResultsAreEqual("String regexp_replace regex 5",
-    nullableStringsFromCsv, conf = conf) {
-    frame => frame.selectExpr("regexp_replace(strings,'(foo)','D')")
+    nullableStringsFromCsv, conf = conf) { frame =>
+    assume(isUnicodeEnabled())
+    frame.selectExpr("regexp_replace(strings,'(foo)','D')")
   }
 
   testSparkResultsAreEqual("String regexp_replace regex 6",
-    nullableStringsFromCsv, conf = conf) {
-    frame => frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')")
+    nullableStringsFromCsv, conf = conf) { frame =>
+    assume(isUnicodeEnabled())
+    frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')")
   }
 
   // https://github.com/NVIDIA/spark-rapids/issues/5659
@@ -106,8 +114,9 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite {
   // note that regexp_extract with a literal string gets replaced with the literal result of
   // the regexp_extract call on CPU
   testSparkResultsAreEqual("String regexp_extract literal input",
-    extractStrings, conf = conf) {
-    frame => frame.selectExpr("regexp_extract('abc123def', '^([a-z]*)([0-9]*)([a-z]*)$', 2)")
+    extractStrings, conf = conf) { frame =>
+    assume(isUnicodeEnabled())
+    frame.selectExpr("regexp_extract('abc123def', '^([a-z]*)([0-9]*)([a-z]*)$', 2)")
   }
 
   private def extractStrings(session: SparkSession): DataFrame = {
@@ -121,4 +130,8 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite {
     ).toDF("strings")
   }
 
+  private def isUnicodeEnabled(): Boolean = {
+    Charset.defaultCharset().name() == "UTF-8"
+  }
+
 }
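(Editor's note, not part of the diff: the `assume(isUnicodeEnabled())` calls threaded through these suites use ScalaTest's `assume`, which cancels a test rather than failing it when the condition does not hold, so non-UTF-8 CI environments report the tests as canceled instead of red. A minimal sketch of the pattern, with a hypothetical suite name and in the same ScalaTest style as the suites above:)

```scala
import java.nio.charset.Charset

import org.scalatest.FunSuite

class AssumeSketch extends FunSuite {
  test("GPU regex comparison (requires a UTF-8 locale)") {
    // Throws TestCanceledException (canceled, not failed) when the
    // default charset is anything other than UTF-8.
    assume(Charset.defaultCharset().name() == "UTF-8")
    assert("abc".matches("[a-c]+"))
  }
}
```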
"\\S\\b", "\\S\\B", "\\H\\B", + "\\H\\b", "\\V\\B", "\\V\\b") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexFindMode, + "Word boundaries around \\D, \\S,\\W, \\H, or \\V are not supported") + ) + } + + test("word boundaries around negated character class - fall back to CPU") { + val patterns = Seq("[^A-Z]\\B", "[^A-Z]\\b") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexFindMode, + "Word boundaries around negated character classes are not supported") + ) + } + test ("word boundaries will fall back to CPU - split") { val patterns = Seq("\\b", "\\B") patterns.foreach(pattern => @@ -583,6 +601,22 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { RegexReplaceMode) } + test("AST fuzz test - regexp_find - full unicode input") { + assume(isUnicodeEnabled()) + doAstFuzzTest(None, REGEXP_LIMITED_CHARS_REPLACE, + RegexFindMode) + } + + test("AST fuzz test - regexp_replace - full unicode input") { + assume(isUnicodeEnabled()) + doAstFuzzTest(None, REGEXP_LIMITED_CHARS_REPLACE, + RegexReplaceMode) + } + + def isUnicodeEnabled(): Boolean = { + Charset.defaultCharset().name() == "UTF-8" + } + test("AST fuzz test - regexp_find - anchor focused") { doAstFuzzTest(validDataChars = Some("\r\nabc"), validPatternChars = "^$\\AZz\r\n()[]-", mode = RegexFindMode) @@ -708,8 +742,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val data = Range(0, 1000) .map(_ => dataGen.nextString()) + val skipUnicodeIssues = validDataChars match { + case None => true + case _ => false + } + // generate patterns that are valid on both CPU and GPU - val fuzzer = new FuzzRegExp(validPatternChars) + val fuzzer = new FuzzRegExp(validPatternChars, skipUnicodeIssues = skipUnicodeIssues) val patterns = HashSet[String]() while (patterns.size < 5000) { val pattern = fuzzer.generate(0).toRegexString @@ -880,7 +919,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { * See https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html for * Java regular expression syntax. */ -class FuzzRegExp(suggestedChars: String, skipKnownIssues: Boolean = true) { +class FuzzRegExp(suggestedChars: String, skipKnownIssues: Boolean = true, + skipUnicodeIssues: Boolean = false) { private val maxDepth = 5 private val rr = new Random(0) @@ -984,7 +1024,12 @@ class FuzzRegExp(suggestedChars: String, skipKnownIssues: Boolean = true) { /** Any escaped character */ private def escapedChar: RegexEscaped = { - RegexEscaped(char.ch) + var ch = '\u0000' + do { + ch = chars(rr.nextInt(chars.length)) + // see https://github.com/NVIDIA/spark-rapids/issues/5882 for \B and \b issue + } while (skipUnicodeIssues && "bB".contains(ch)) + RegexEscaped(ch) } private def lineTerminator: RegexAST = { @@ -1000,16 +1045,22 @@ class FuzzRegExp(suggestedChars: String, skipKnownIssues: Boolean = true) { } private def boundaryMatch: RegexAST = { - val generators = Seq[() => RegexAST]( + val baseGenerators = Seq[() => RegexAST]( () => RegexChar('^'), () => RegexChar('$'), - () => RegexEscaped('b'), - () => RegexEscaped('B'), () => RegexEscaped('A'), () => RegexEscaped('G'), () => RegexEscaped('Z'), () => RegexEscaped('z') ) + val generators = if (skipUnicodeIssues) { + baseGenerators + } else { + baseGenerators ++ Seq[() => RegexAST]( + // see https://github.com/NVIDIA/spark-rapids/issues/5882 for \B and \b issue + () => RegexEscaped('b'), + () => RegexEscaped('B')) + } generators(rr.nextInt(generators.length))() }