diff --git a/NOTICE b/NOTICE
index eb8aaeb9977..657bfecd06a 100644
--- a/NOTICE
+++ b/NOTICE
@@ -48,6 +48,17 @@ The Apache Software Foundation (http://www.apache.org/).
--------------------------------------------------------------------------------
+This project includes software from the Apache Gluten project
+(https://github.com/apache/incubator-gluten/).
+
+Apache Gluten (Incubating)
+Copyright 2024 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+--------------------------------------------------------------------------------
+
This project includes code from Kite, developed at Cloudera, Inc. with
the following copyright notice:
diff --git a/pom.xml b/pom.xml
index b45eca7abeb..c939f8d5891 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1039,6 +1039,53 @@
         <version>${mockito.version}</version>
         <scope>test</scope>
       </dependency>
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-core_${scala.binary.version}</artifactId>
+        <version>${spark.version}</version>
+        <type>test-jar</type>
+        <scope>test</scope>
+        <exclusions>
+          <exclusion>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-client</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-client-api</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-client-runtime</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.apache.curator</groupId>
+            <artifactId>curator-recipes</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-log4j12</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>log4j</groupId>
+            <artifactId>log4j</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+        <version>${spark.version}</version>
+        <type>test-jar</type>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-sql_${scala.binary.version}</artifactId>
+        <version>${spark.version}</version>
+        <type>test-jar</type>
+        <scope>test</scope>
+      </dependency>
diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml
index c60e0b19034..cf3bfb48373 100644
--- a/scala2.13/pom.xml
+++ b/scala2.13/pom.xml
@@ -1039,6 +1039,53 @@
         <version>${mockito.version}</version>
         <scope>test</scope>
       </dependency>
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-core_${scala.binary.version}</artifactId>
+        <version>${spark.version}</version>
+        <type>test-jar</type>
+        <scope>test</scope>
+        <exclusions>
+          <exclusion>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-client</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-client-api</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-client-runtime</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.apache.curator</groupId>
+            <artifactId>curator-recipes</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-log4j12</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>log4j</groupId>
+            <artifactId>log4j</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+        <version>${spark.version}</version>
+        <type>test-jar</type>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-sql_${scala.binary.version}</artifactId>
+        <version>${spark.version}</version>
+        <type>test-jar</type>
+        <scope>test</scope>
+      </dependency>
diff --git a/scala2.13/tests/pom.xml b/scala2.13/tests/pom.xml
index d90842c7955..a9c1d707b27 100644
--- a/scala2.13/tests/pom.xml
+++ b/scala2.13/tests/pom.xml
@@ -103,6 +103,27 @@
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-avro_${scala.binary.version}</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>scalatestplus-scalacheck_${scala.binary.version}</artifactId>
+      <version>3.1.0.0-RC2</version>
+      <scope>test</scope>
+    </dependency>
diff --git a/tests/pom.xml b/tests/pom.xml
index 71143614f7b..96c2a051143 100644
--- a/tests/pom.xml
+++ b/tests/pom.xml
@@ -103,6 +103,27 @@
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-avro_${scala.binary.version}</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>scalatestplus-scalacheck_${scala.binary.version}</artifactId>
+      <version>3.1.0.0-RC2</version>
+      <scope>test</scope>
+    </dependency>
diff --git a/tests/src/test/java/com/nvidia/spark/rapids/TestStats.java b/tests/src/test/java/com/nvidia/spark/rapids/TestStats.java
new file mode 100644
index 00000000000..3f367529bcd
--- /dev/null
+++ b/tests/src/test/java/com/nvidia/spark/rapids/TestStats.java
@@ -0,0 +1,205 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+/** Only for use in the unit-test environment; not thread safe. */
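+// A sketch of the intended lifecycle (the suite wiring shown here is illustrative,
+// not part of this class):
+//   TestStats.beginStatistic();          // once, to enable collection in the UT env
+//   TestStats.startCase("some test");    // before each test case
+//   TestStats.endCase(succeeded);        // after each test case
+//   TestStats.printMarkdown("MySuite");  // per suite, emits the HTML table rows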
+public class TestStats {
+ private static final String HEADER_FORMAT = "<tr><th>%s</th><th colspan=5>%s</th></tr>";
+ private static final String ROW_FORMAT =
+ "<tr><td>%s</td><td>%s</td><td>%s</td><td>%s</td><td>%s</td><td>%s</td></tr>";
+
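+ // Enabled via beginStatistic() by the test harness; every hook below no-ops
+ // while this is false.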
+ private static boolean UT_ENV = false;
+ private static final Map<String, CaseInfo> caseInfos = new HashMap<>();
+ private static String currentCase;
+ public static int offloadRapidsUnitNumber = 0;
+ public static int testUnitNumber = 0;
+
+ // whether the current query was offloaded to the RAPIDS backend
+ public static boolean offloadRapids = true;
+ public static int suiteTestNumber = 0;
+ public static int offloadRapidsTestNumber = 0;
+
+ public static void beginStatistic() {
+ UT_ENV = true;
+ }
+
+ public static void reset() {
+ offloadRapids = false;
+ suiteTestNumber = 0;
+ offloadRapidsTestNumber = 0;
+ testUnitNumber = 0;
+ offloadRapidsUnitNumber = 0;
+ resetCase();
+ caseInfos.clear();
+ }
+
+ private static int totalSuiteTestNumber = 0;
+ public static int totalOffloadRapidsTestNumber = 0;
+
+ public static int totalTestUnitNumber = 0;
+ public static int totalOffloadRapidsCaseNumber = 0;
+
+ public static void printMarkdown(String suiteName) {
+ if (!UT_ENV) {
+ return;
+ }
+
+ String title = "print_markdown_" + suiteName;
+
+ String info =
+ "Case Count: %d, OffloadRapids Case Count: %d, "
+ + "Unit Count %d, OffloadRapids Unit Count %d";
+
+ System.out.println(
+ String.format(
+ HEADER_FORMAT,
+ title,
+ String.format(
+ info,
+ TestStats.suiteTestNumber,
+ TestStats.offloadRapidsTestNumber,
+ TestStats.testUnitNumber,
+ TestStats.offloadRapidsUnitNumber)));
+
+ caseInfos.forEach(
+ (key, value) ->
+ System.out.println(
+ String.format(
+ ROW_FORMAT,
+ title,
+ key,
+ value.status,
+ value.type,
+ String.join("
", value.fallbackExpressionName),
+ String.join("
", value.fallbackClassName))));
+
+ totalSuiteTestNumber += suiteTestNumber;
+ totalOffloadRapidsTestNumber += offloadRapidsTestNumber;
+ totalTestUnitNumber += testUnitNumber;
+ totalOffloadRapidsCaseNumber += offloadRapidsUnitNumber;
+ System.out.println(
+ "total_markdown_ totalCaseNum:"
+ + totalSuiteTestNumber
+ + " offloadRapids: "
+ + totalOffloadRapidsTestNumber
+ + " total unit: "
+ + totalTestUnitNumber
+ + " offload unit: "
+ + totalOffloadRapidsCaseNumber);
+ }
+
+ public static void addFallBackClassName(String className) {
+ if (!UT_ENV) {
+ return;
+ }
+
+ if (caseInfos.containsKey(currentCase) && !caseInfos.get(currentCase).stack.isEmpty()) {
+ CaseInfo info = caseInfos.get(currentCase);
+ caseInfos.get(currentCase).fallbackExpressionName.add(info.stack.pop());
+ caseInfos.get(currentCase).fallbackClassName.add(className);
+ }
+ }
+
+ public static void addFallBackCase() {
+ if (!UT_ENV) {
+ return;
+ }
+
+ if (caseInfos.containsKey(currentCase)) {
+ caseInfos.get(currentCase).type = "fallback";
+ }
+ }
+
+ public static void addExpressionClassName(String className) {
+ if (!UT_ENV) {
+ return;
+ }
+
+ if (caseInfos.containsKey(currentCase)) {
+ CaseInfo info = caseInfos.get(currentCase);
+ info.stack.add(className);
+ }
+ }
+
+ public static Set<String> getFallBackClassName() {
+ if (!UT_ENV) {
+ return Collections.emptySet();
+ }
+
+ if (caseInfos.containsKey(currentCase)) {
+ return Collections.unmodifiableSet(caseInfos.get(currentCase).fallbackExpressionName);
+ }
+
+ return Collections.emptySet();
+ }
+
+ public static void addIgnoreCaseName(String caseName) {
+ if (!UT_ENV) {
+ return;
+ }
+
+ if (caseInfos.containsKey(caseName)) {
+ caseInfos.get(caseName).type = "fatal";
+ }
+ }
+
+ public static void resetCase() {
+ if (!UT_ENV) {
+ return;
+ }
+
+ if (caseInfos.containsKey(currentCase)) {
+ caseInfos.get(currentCase).stack.clear();
+ }
+ currentCase = "";
+ }
+
+ public static void startCase(String caseName) {
+ if (!UT_ENV) {
+ return;
+ }
+
+ caseInfos.putIfAbsent(caseName, new CaseInfo());
+ currentCase = caseName;
+ }
+
+ public static void endCase(boolean status) {
+ if (!UT_ENV) {
+ return;
+ }
+
+ if (caseInfos.containsKey(currentCase)) {
+ caseInfos.get(currentCase).status = status ? "success" : "error";
+ }
+
+ resetCase();
+ }
+}
+
+class CaseInfo {
+ final Stack<String> stack = new Stack<>();
+ Set<String> fallbackExpressionName = new HashSet<>();
+ Set<String> fallbackClassName = new HashSet<>();
+ String type = "";
+ String status = "";
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsCastSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsCastSuite.scala
new file mode 100644
index 00000000000..f3fec27f7f6
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsCastSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.suites
+
+import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, CastSuite, Expression, Literal}
+import org.apache.spark.sql.rapids.utils.RapidsTestsTrait
+import org.apache.spark.sql.types._
+
+class RapidsCastSuite extends CastSuite with RapidsTestsTrait {
+ // example to enhance logging for base suite
+ override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
+ v match {
+ case lit: Expression =>
+ logDebug(s"Cast from: ${lit.dataType.typeName}, to: ${targetType.typeName}")
+ Cast(lit, targetType, timeZoneId)
+ case _ =>
+ val lit = Literal(v)
+ logDebug(s"Cast from: ${lit.dataType.typeName}, to: ${targetType.typeName}")
+ Cast(lit, targetType, timeZoneId)
+ }
+ }
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsDataFrameAggregateSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsDataFrameAggregateSuite.scala
new file mode 100644
index 00000000000..5a394a5b0e8
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsDataFrameAggregateSuite.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.suites
+
+import org.apache.spark.sql.DataFrameAggregateSuite
+import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait
+
+class RapidsDataFrameAggregateSuite extends DataFrameAggregateSuite with RapidsSQLTestsTrait {
+ // example to show how to replace the logic of an excluded test case in Vanilla Spark
+ testRapids("collect functions" ) { // "collect functions" was excluded at RapidsTestSettings
+ // println("...")
+ }
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonFunctionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonFunctionsSuite.scala
new file mode 100644
index 00000000000..43150c0df4b
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonFunctionsSuite.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.suites
+
+import org.apache.spark.sql.JsonFunctionsSuite
+import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait
+
+class RapidsJsonFunctionsSuite extends JsonFunctionsSuite with RapidsSQLTestsTrait {}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonSuite.scala
new file mode 100644
index 00000000000..6d244c67ad0
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.suites
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.datasources.{InMemoryFileIndex, NoopCache}
+import org.apache.spark.sql.execution.datasources.json.JsonSuite
+import org.apache.spark.sql.execution.datasources.v2.json.JsonScanBuilder
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.utils.RapidsSQLTestsBaseTrait
+import org.apache.spark.sql.sources
+import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class RapidsJsonSuite extends JsonSuite with RapidsSQLTestsBaseTrait {
+
+ /** Returns full path to the given file in the resource folder */
+ override protected def testFile(fileName: String): String = {
+ getWorkspaceFilePath("sql", "core", "src", "test", "resources").toString + "/" + fileName
+ }
+}
+
+class RapidsJsonV1Suite extends RapidsJsonSuite with RapidsSQLTestsBaseTrait {
+ override def sparkConf: SparkConf =
+ super.sparkConf
+ .set(SQLConf.USE_V1_SOURCE_LIST, "json")
+}
+
+class RapidsJsonV2Suite extends RapidsJsonSuite with RapidsSQLTestsBaseTrait {
+ override def sparkConf: SparkConf =
+ super.sparkConf
+ .set(SQLConf.USE_V1_SOURCE_LIST, "")
+
+ test("get pushed filters") {
+ val attr = "col"
+ def getBuilder(path: String): JsonScanBuilder = {
+ val fileIndex = new InMemoryFileIndex(
+ spark,
+ Seq(new org.apache.hadoop.fs.Path(path, "file.json")),
+ Map.empty,
+ None,
+ NoopCache)
+ val schema = new StructType().add(attr, IntegerType)
+ val options = CaseInsensitiveStringMap.empty()
+ new JsonScanBuilder(spark, fileIndex, schema, schema, options)
+ }
+ val filters: Array[sources.Filter] = Array(sources.IsNotNull(attr))
+ withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+ withTempPath {
+ file =>
+ val scanBuilder = getBuilder(file.getCanonicalPath)
+ assert(scanBuilder.pushDataFilters(filters) === filters)
+ }
+ }
+
+ withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "false") {
+ withTempPath {
+ file =>
+ val scanBuilder = getBuilder(file.getCanonicalPath)
+ assert(scanBuilder.pushDataFilters(filters) === Array.empty[sources.Filter])
+ }
+ }
+ }
+}
+
+class RapidsJsonLegacyTimeParserSuite extends RapidsJsonSuite with RapidsSQLTestsBaseTrait {
+ override def sparkConf: SparkConf =
+ super.sparkConf
+ .set(SQLConf.LEGACY_TIME_PARSER_POLICY, "legacy")
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsMathFunctionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsMathFunctionsSuite.scala
new file mode 100644
index 00000000000..55b4b00f680
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsMathFunctionsSuite.scala
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.suites
+
+import org.apache.spark.sql.MathFunctionsSuite
+import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait
+
+class RapidsMathFunctionsSuite extends MathFunctionsSuite with RapidsSQLTestsTrait {
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsRegexpExpressionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsRegexpExpressionsSuite.scala
new file mode 100644
index 00000000000..95b54240dbe
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsRegexpExpressionsSuite.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.suites
+
+import org.apache.spark.sql.catalyst.expressions.RegexpExpressionsSuite
+import org.apache.spark.sql.rapids.utils.RapidsTestsTrait
+
+class RapidsRegexpExpressionsSuite extends RegexpExpressionsSuite with RapidsTestsTrait {}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringExpressionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringExpressionsSuite.scala
new file mode 100644
index 00000000000..164406fdf83
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringExpressionsSuite.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.suites
+
+import org.apache.spark.sql.catalyst.expressions.StringExpressionsSuite
+import org.apache.spark.sql.rapids.utils.RapidsTestsTrait
+
+class RapidsStringExpressionsSuite extends StringExpressionsSuite with RapidsTestsTrait {}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringFunctionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringFunctionsSuite.scala
new file mode 100644
index 00000000000..7b4a8ac6d7d
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringFunctionsSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.suites
+
+import org.apache.spark.sql.StringFunctionsSuite
+import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait
+
+class RapidsStringFunctionsSuite
+ extends StringFunctionsSuite
+ with RapidsSQLTestsTrait {
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/BackendTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/BackendTestSettings.scala
new file mode 100644
index 00000000000..83396e977fa
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/BackendTestSettings.scala
@@ -0,0 +1,215 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.utils
+
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
+
+import com.nvidia.spark.rapids.TestStats
+
+import org.apache.spark.sql.rapids.utils.RapidsTestConstants.RAPIDS_TEST
+
+abstract class BackendTestSettings {
+
+ private val enabledSuites: java.util.Map[String, SuiteSettings] = new util.HashMap()
+
+ protected def enableSuite[T: ClassTag]: SuiteSettings = {
+ val suiteName = implicitly[ClassTag[T]].runtimeClass.getCanonicalName
+ if (enabledSuites.containsKey(suiteName)) {
+ throw new IllegalArgumentException("Duplicated suite name: " + suiteName)
+ }
+ val suiteSettings = new SuiteSettings
+ enabledSuites.put(suiteName, suiteSettings)
+ suiteSettings
+ }
+
+ private[utils] def shouldRun(suiteName: String, testName: String): Boolean = {
+ if (!enabledSuites.containsKey(suiteName)) {
+ return false
+ }
+
+ val suiteSettings = enabledSuites.get(suiteName)
+
+ val inclusion = suiteSettings.inclusion.asScala
+ val exclusion = suiteSettings.exclusion.asScala
+
+ if (inclusion.isEmpty && exclusion.isEmpty) {
+ // default to run all cases under this suite
+ return true
+ }
+
+ if (inclusion.nonEmpty && exclusion.nonEmpty) {
+ // error
+ throw new IllegalStateException(
+ s"Do not use include and exclude conditions on the same test case: $suiteName:$testName")
+ }
+
+ if (inclusion.nonEmpty) {
+ // include mode
+ val isIncluded = inclusion.exists(_.isIncluded(testName))
+ return isIncluded
+ }
+
+ if (exclusion.nonEmpty) {
+ // exclude mode
+ val isExcluded = exclusion.exists(_.isExcluded(testName))
+ return !isExcluded
+ }
+
+ throw new IllegalStateException("Unreachable code")
+ }
+
+ sealed trait ExcludeReason
+ // The reason should most likely be an issue link,
+ // or a description like "This simply can't work on GPU".
+ // It should never be "unknown" or "needs investigation".
+ case class KNOWN_ISSUE(reason: String) extends ExcludeReason
+ case class WONT_FIX_ISSUE(reason: String) extends ExcludeReason
+
+
+ final protected class SuiteSettings {
+ private[utils] val inclusion: util.List[IncludeBase] = new util.ArrayList()
+ private[utils] val exclusion: util.List[ExcludeBase] = new util.ArrayList()
+ private[utils] val excludeReasons: util.List[ExcludeReason] = new util.ArrayList()
+
+ def include(testNames: String*): SuiteSettings = {
+ inclusion.add(Include(testNames: _*))
+ this
+ }
+ def exclude(testNames: String, reason: ExcludeReason): SuiteSettings = {
+ exclusion.add(Exclude(testNames))
+ excludeReasons.add(reason)
+ this
+ }
+ def includeRapidsTest(testName: String*): SuiteSettings = {
+ inclusion.add(IncludeRapidsTest(testName: _*))
+ this
+ }
+ def excludeRapidsTest(testName: String, reason: ExcludeReason): SuiteSettings = {
+ exclusion.add(ExcludeRapidsTest(testName))
+ excludeReasons.add(reason)
+ this
+ }
+ def includeByPrefix(prefixes: String*): SuiteSettings = {
+ inclusion.add(IncludeByPrefix(prefixes: _*))
+ this
+ }
+ def excludeByPrefix(prefixes: String, reason: ExcludeReason): SuiteSettings = {
+ exclusion.add(ExcludeByPrefix(prefixes))
+ excludeReasons.add(reason)
+ this
+ }
+ def includeRapidsTestsByPrefix(prefixes: String*): SuiteSettings = {
+ inclusion.add(IncludeRapidsTestByPrefix(prefixes: _*))
+ this
+ }
+ def excludeRapidsTestsByPrefix(prefixes: String, reason: ExcludeReason): SuiteSettings = {
+ exclusion.add(ExcludeRapidsTestByPrefix(prefixes))
+ excludeReasons.add(reason)
+ this
+ }
+ def includeAllRapidsTests(): SuiteSettings = {
+ inclusion.add(IncludeByPrefix(RAPIDS_TEST))
+ this
+ }
+ def excludeAllRapidsTests(reason: ExcludeReason): SuiteSettings = {
+ exclusion.add(ExcludeByPrefix(RAPIDS_TEST))
+ excludeReasons.add(reason)
+ this
+ }
+ }
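+
+ // Usage sketch (illustrative): a concrete settings class chains these calls, e.g.
+ //   enableSuite[SomeSuite]
+ //     .exclude("some flaky case", KNOWN_ISSUE("<tracking issue link>"))
+ //     .includeByPrefix("SPARK-")
+ // See RapidsTestSettings later in this patch for the real configuration.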
+
+ protected trait IncludeBase {
+ def isIncluded(testName: String): Boolean
+ }
+ protected trait ExcludeBase {
+ def isExcluded(testName: String): Boolean
+ }
+ private case class Include(testNames: String*) extends IncludeBase {
+ val nameSet: Set[String] = Set(testNames: _*)
+ override def isIncluded(testName: String): Boolean = nameSet.contains(testName)
+ }
+ private case class Exclude(testNames: String*) extends ExcludeBase {
+ val nameSet: Set[String] = Set(testNames: _*)
+ override def isExcluded(testName: String): Boolean = nameSet.contains(testName)
+ }
+ private case class IncludeRapidsTest(testNames: String*) extends IncludeBase {
+ val nameSet: Set[String] = testNames.map(name => RAPIDS_TEST + name).toSet
+ override def isIncluded(testName: String): Boolean = nameSet.contains(testName)
+ }
+ private case class ExcludeRapidsTest(testNames: String*) extends ExcludeBase {
+ val nameSet: Set[String] = testNames.map(name => RAPIDS_TEST + name).toSet
+ override def isExcluded(testName: String): Boolean = nameSet.contains(testName)
+ }
+ private case class IncludeByPrefix(prefixes: String*) extends IncludeBase {
+ override def isIncluded(testName: String): Boolean = {
+ if (prefixes.exists(prefix => testName.startsWith(prefix))) {
+ return true
+ }
+ false
+ }
+ }
+ private case class ExcludeByPrefix(prefixes: String*) extends ExcludeBase {
+ override def isExcluded(testName: String): Boolean = {
+ if (prefixes.exists(prefix => testName.startsWith(prefix))) {
+ return true
+ }
+ false
+ }
+ }
+ private case class IncludeRapidsTestByPrefix(prefixes: String*) extends IncludeBase {
+ override def isIncluded(testName: String): Boolean = {
+ if (prefixes.exists(prefix => testName.startsWith(RAPIDS_TEST + prefix))) {
+ return true
+ }
+ false
+ }
+ }
+ private case class ExcludeRapidsTestByPrefix(prefixes: String*) extends ExcludeBase {
+ override def isExcluded(testName: String): Boolean = {
+ if (prefixes.exists(prefix => testName.startsWith(RAPIDS_TEST + prefix))) {
+ return true
+ }
+ false
+ }
+ }
+}
+
+object BackendTestSettings {
+ val instance: BackendTestSettings = {
+ Class
+ .forName("org.apache.spark.sql.rapids.utils.RapidsTestSettings")
+ .getDeclaredConstructor()
+ .newInstance()
+ .asInstanceOf[BackendTestSettings]
+ }
+
+ def shouldRun(suiteName: String, testName: String): Boolean = {
+ val v = instance.shouldRun(suiteName, testName)
+
+ if (!v) {
+ TestStats.addIgnoreCaseName(testName)
+ }
+
+ v
+ }
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsBaseTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsBaseTrait.scala
new file mode 100644
index 00000000000..540c70a2ee1
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsBaseTrait.scala
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.utils
+
+import java.util.{Locale, TimeZone}
+
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.Tests.IS_TESTING
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.rapids.utils.RapidsTestConstants.RAPIDS_TEST
+import org.apache.spark.sql.test.SharedSparkSession
+
+
+/** Basic trait for Rapids SQL test cases. */
+trait RapidsSQLTestsBaseTrait extends SharedSparkSession with RapidsTestsBaseTrait {
+
+ protected override def afterAll(): Unit = {
+ // SparkFunSuite sets this to true but forgets to reset it to false.
+ System.clearProperty(IS_TESTING.key)
+ super.afterAll()
+ }
+
+ protected def testRapids(testName: String, testTag: Tag*)(testFun: => Any)(implicit
+ pos: Position): Unit = {
+ test(RAPIDS_TEST + testName, testTag: _*)(testFun)
+ }
+
+ override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
+ pos: Position): Unit = {
+ if (shouldRun(testName)) {
+ super.test(testName, testTags: _*)(testFun)
+ } else {
+ super.ignore(testName, testTags: _*)(testFun)
+ }
+ }
+
+ override def sparkConf: SparkConf = {
+ RapidsSQLTestsBaseTrait.nativeSparkConf(super.sparkConf, warehouse)
+ }
+
+ /**
+ * Recursively collect all child plans of the given plans.
+ *
+ * @param plans the input plans.
+ * @return the child plans, with shuffle query stages unwrapped and
+ *         WholeStageCodegen wrappers skipped.
+ */
+ private def getChildrenPlan(plans: Seq[SparkPlan]): Seq[SparkPlan] = {
+ if (plans.isEmpty) {
+ return Seq()
+ }
+
+ val inputPlans: Seq[SparkPlan] = plans.map {
+ case stage: ShuffleQueryStageExec => stage.plan
+ case plan => plan
+ }
+
+ var newChildren: Seq[SparkPlan] = Seq()
+ inputPlans.foreach {
+ plan =>
+ newChildren = newChildren ++ getChildrenPlan(plan.children)
+ // To avoid duplication of WholeStageCodegenXXX and its children.
+ if (!plan.nodeName.startsWith("WholeStageCodegen")) {
+ newChildren = newChildren :+ plan
+ }
+ }
+ newChildren
+ }
+
+ /**
+ * Get the executed plan of a DataFrame.
+ *
+ * @param df the input DataFrame.
+ * @return a sequence of executed plans.
+ */
+ def getExecutedPlan(df: DataFrame): Seq[SparkPlan] = {
+ df.queryExecution.executedPlan match {
+ case exec: AdaptiveSparkPlanExec =>
+ getChildrenPlan(Seq(exec.executedPlan))
+ case plan =>
+ getChildrenPlan(Seq(plan))
+ }
+ }
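+
+ // Usage sketch (illustrative): a suite can assert that a query was offloaded by
+ // looking for GPU operators in the executed plan, e.g.
+ //   assert(getExecutedPlan(df).exists(_.isInstanceOf[GpuProjectExec]))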
+}
+
+object RapidsSQLTestsBaseTrait {
+ def nativeSparkConf(origin: SparkConf, warehouse: String): SparkConf = {
+ // Timezone is fixed to UTC to allow timestamps to work by default
+ TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
+ // Add Locale setting
+ Locale.setDefault(Locale.US)
+
+ val conf = origin
+ .set("spark.rapids.sql.enabled", "true")
+ .set("spark.plugins", "com.nvidia.spark.SQLPlugin")
+ .set("spark.sql.queryExecutionListeners",
+ "org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback")
+ .set("spark.sql.warehouse.dir", warehouse)
+ .set("spark.sql.cache.serializer", "com.nvidia.spark.ParquetCachedBatchSerializer")
+ .setAppName("rapids spark plugin running Vanilla Spark UT")
+
+ conf
+ }
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsTrait.scala
new file mode 100644
index 00000000000..4358e29630c
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsTrait.scala
@@ -0,0 +1,273 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.utils
+
+import java.io.File
+import java.util.TimeZone
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.{FileUtils => fu}
+import org.apache.commons.math3.util.Precision
+import org.scalatest.Assertions
+
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.util.{sideBySide, stackTraceToString}
+import org.apache.spark.sql.execution.SQLExecution
+
+
+/** Basic trait for Rapids SQL test cases. */
+trait RapidsSQLTestsTrait extends QueryTest with RapidsSQLTestsBaseTrait {
+
+ def prepareWorkDir(): Unit = {
+ // prepare working paths
+ val basePathDir = new File(basePath)
+ if (basePathDir.exists()) {
+ fu.forceDelete(basePathDir)
+ }
+ fu.forceMkdir(basePathDir)
+ fu.forceMkdir(new File(warehouse))
+ fu.forceMkdir(new File(metaStorePathAbsolute))
+ }
+
+ override def beforeAll(): Unit = {
+ prepareWorkDir()
+ super.beforeAll()
+ spark.sparkContext.setLogLevel("WARN")
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ }
+
+ override protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ val analyzedDF =
+ try df
+ catch {
+ case ae: AnalysisException =>
+ val plan = ae.plan
+ if (plan.isDefined) {
+ fail(s"""
+ |Failed to analyze query: $ae
+ |${plan.get}
+ |
+ |${stackTraceToString(ae)}
+ |""".stripMargin)
+ } else {
+ throw ae
+ }
+ }
+
+ assertEmptyMissingInput(analyzedDF)
+
+ RapidsQueryTestUtil.checkAnswer(analyzedDF, expectedAnswer)
+ }
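+
+ // The checkAnswer override above routes every QueryTest.checkAnswer call through
+ // RapidsQueryTestUtil below, so GPU results are compared with its tolerant
+ // 'compare' semantics rather than Spark's exact equality.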
+}
+
+object RapidsQueryTestUtil extends Assertions {
+
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ *
+ * @param df
+ * the DataFrame to be executed
+ * @param expectedAnswer
+ * the expected result in a Seq of Rows.
+ * @param checkToRDD
+ * whether to verify deserialization to an RDD. This runs the query twice.
+ */
+ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Unit = {
+ getErrorMessageInCheckAnswer(df, expectedAnswer, checkToRDD) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ /**
+ * Runs the plan and makes sure the answer matches the expected result. If an exception is
+ * thrown during execution, or the contents of the DataFrame do not match the expected result,
+ * an error message is returned; otherwise, None is returned.
+ *
+ * @param df
+ * the DataFrame to be executed
+ * @param expectedAnswer
+ * the expected result in a Seq of Rows.
+ * @param checkToRDD
+ * whether to verify deserialization to an RDD. This runs the query twice.
+ */
+ def getErrorMessageInCheckAnswer(
+ df: DataFrame,
+ expectedAnswer: Seq[Row],
+ checkToRDD: Boolean = true): Option[String] = {
+ val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+ if (checkToRDD) {
+ SQLExecution.withSQLConfPropagated(df.sparkSession) {
+ df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791]
+ }
+ }
+
+ val sparkAnswer =
+ try df.collect().toSeq
+ catch {
+ case e: Exception =>
+ val errorMessage =
+ s"""
+ |Exception thrown while executing query:
+ |${df.queryExecution}
+ |== Exception ==
+ |$e
+ |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ sameRows(expectedAnswer, sparkAnswer, isSorted).map {
+ results =>
+ s"""
+ |Results do not match for query:
+ |Timezone: ${TimeZone.getDefault}
+ |Timezone Env: ${sys.env.getOrElse("TZ", "")}
+ |
+ |${df.queryExecution}
+ |== Results ==
+ |$results
+ """.stripMargin
+ }
+ }
+
+ def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
+ // Converts data to types that we can do equality comparison using Scala collections.
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+ // Java's java.math.BigDecimal.compareTo).
+ // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
+ // the equality test.
+ val converted: Seq[Row] = answer.map(prepareRow)
+ if (!isSorted) converted.sortBy(_.toString()) else converted
+ }
+
+ // We need to call prepareRow recursively to handle schemas with struct types.
+ def prepareRow(row: Row): Row = {
+ Row.fromSeq(row.toSeq.map {
+ case null => null
+ case bd: java.math.BigDecimal => BigDecimal(bd)
+ // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+
+ case seq: Seq[_] =>
+ seq.map {
+ case b: java.lang.Byte => b.byteValue
+ case s: java.lang.Short => s.shortValue
+ case i: java.lang.Integer => i.intValue
+ case l: java.lang.Long => l.longValue
+ case f: java.lang.Float => f.floatValue
+ case d: java.lang.Double => d.doubleValue
+ case x => x
+ }
+ // Convert array to Seq for easy equality check.
+ case b: Array[_] => b.toSeq
+ case r: Row => prepareRow(r)
+ case o => o
+ })
+ }
+
+ private def genError(
+ expectedAnswer: Seq[Row],
+ sparkAnswer: Seq[Row],
+ isSorted: Boolean): String = {
+ val getRowType: Option[Row] => String = row =>
+ row
+ .map(
+ row =>
+ if (row.schema == null) {
+ "struct<>"
+ } else {
+ s"${row.schema.catalogString}"
+ })
+ .getOrElse("struct<>")
+
+ s"""
+ |== Results ==
+ |${sideBySide(
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ getRowType(expectedAnswer.headOption) +:
+ prepareAnswer(expectedAnswer, isSorted).map(_.toString()),
+ s"== RAPIDS Answer - ${sparkAnswer.size} ==" +:
+ getRowType(sparkAnswer.headOption) +:
+ prepareAnswer(sparkAnswer, isSorted).map(_.toString())
+ ).mkString("\n")}
+ """.stripMargin
+ }
+
+ def includesRows(expectedRows: Seq[Row], sparkAnswer: Seq[Row]): Option[String] = {
+ if (!prepareAnswer(expectedRows, true).toSet.subsetOf(prepareAnswer(sparkAnswer, true).toSet)) {
+ return Some(genError(expectedRows, sparkAnswer, true))
+ }
+ None
+ }
+
+ private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
+ case (null, null) => true
+ case (null, _) => false
+ case (_, null) => false
+ case (a: Array[_], b: Array[_]) =>
+ a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r) }
+ case (a: Map[_, _], b: Map[_, _]) =>
+ a.size == b.size && a.keys.forall {
+ aKey => b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey)))
+ }
+ case (a: Iterable[_], b: Iterable[_]) =>
+ a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r) }
+ case (a: Product, b: Product) =>
+ compare(a.productIterator.toSeq, b.productIterator.toSeq)
+ case (a: Row, b: Row) =>
+ compare(a.toSeq, b.toSeq)
+ // Since 0.0 == -0.0, compare float/double bit patterns to distinguish 0.0 from -0.0.
+ case (a: Double, b: Double) =>
+ if ((isNaNOrInf(a) || isNaNOrInf(b)) || (a == -0.0) || (b == -0.0)) {
+ java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
+ } else {
+ Precision.equalsWithRelativeTolerance(a, b, 0.00001d)
+ }
+ case (a: Float, b: Float) =>
+ java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
+ case (a, b) => a == b
+ }
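+
+ // Illustrative consequences of the rules above: compare(1.0d, 1.0000000001d) is true
+ // via the relative tolerance, while compare(0.0d, -0.0d) is false because the raw
+ // bit patterns differ.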
+
+ def isNaNOrInf(num: Double): Boolean = {
+ // isInfinite already covers both positive and negative infinity.
+ num.isNaN || num.isInfinite
+ }
+
+ def sameRows(
+ expectedAnswer: Seq[Row],
+ sparkAnswer: Seq[Row],
+ isSorted: Boolean = false): Option[String] = {
+ // Row equality is delegated to the custom 'compare' above (bit-exact for floats,
+ // relative tolerance for doubles).
+ if (!compare(prepareAnswer(expectedAnswer, isSorted), prepareAnswer(sparkAnswer, isSorted))) {
+ return Some(genError(expectedAnswer, sparkAnswer, isSorted))
+ }
+ None
+ }
+
+ def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = {
+ getErrorMessageInCheckAnswer(df, expectedAnswer.asScala.toSeq) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestConstants.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestConstants.scala
new file mode 100644
index 00000000000..772becaa5f9
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestConstants.scala
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.utils
+
+import org.apache.spark.sql.types._
+
+object RapidsTestConstants {
+
+ val RAPIDS_TEST: String = "Rapids - "
+
+ val IGNORE_ALL: String = "IGNORE_ALL"
+
+ val SUPPORTED_DATA_TYPE = TypeCollection(
+ BooleanType,
+ ByteType,
+ ShortType,
+ IntegerType,
+ LongType,
+ FloatType,
+ DoubleType,
+ DecimalType,
+ StringType,
+ BinaryType,
+ DateType,
+ TimestampType,
+ ArrayType,
+ StructType,
+ MapType
+ )
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
new file mode 100644
index 00000000000..4981c385219
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.utils
+
+import org.apache.spark.sql.rapids.suites.{RapidsCastSuite, RapidsDataFrameAggregateSuite, RapidsJsonFunctionsSuite, RapidsJsonSuite, RapidsMathFunctionsSuite, RapidsRegexpExpressionsSuite, RapidsStringExpressionsSuite, RapidsStringFunctionsSuite}
+
+// Some settings' line length exceeds 100
+// scalastyle:off line.size.limit
+
+class RapidsTestSettings extends BackendTestSettings {
+
+ enableSuite[RapidsCastSuite]
+ .exclude("Process Infinity, -Infinity, NaN in case insensitive manner", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
+ .exclude("SPARK-35711: cast timestamp without time zone to timestamp with local time zone", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
+ .exclude("SPARK-35719: cast timestamp with local time zone to timestamp without timezone", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
+ .exclude("SPARK-35112: Cast string to day-time interval", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
+ .exclude("SPARK-35735: Take into account day-time interval fields in cast", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
+ .exclude("casting to fixed-precision decimals", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
+ .exclude("SPARK-32828: cast from a derived user-defined type to a base type", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
+ enableSuite[RapidsDataFrameAggregateSuite]
+ .exclude("collect functions", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
+ .exclude("collect functions structs", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
+ .exclude("collect functions should be able to cast to array type with no null values", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
+ .exclude("SPARK-17641: collect functions should not collect null values", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
+ .exclude("SPARK-19471: AggregationIterator does not initialize the generated result projection before using it", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
+ enableSuite[RapidsJsonFunctionsSuite]
+ enableSuite[RapidsJsonSuite]
+ .exclude("Casting long as timestamp", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("Write timestamps correctly with timestampFormat option and timeZone option", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-23723: json in UTF-16 with BOM", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-23723: multi-line json in UTF-32BE with BOM", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-23723: Unsupported encoding name", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-23723: checking that the encoding option is case agnostic", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-23723: specified encoding is not matched to actual encoding", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-23724: lineSep should be set if encoding if different from UTF-8", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-31716: inferring should handle malformed input", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-24190: restrictions for JSONOptions in read", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("exception mode for parsing date/timestamp string", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ .exclude("SPARK-37360: Timestamp type inference for a mix of TIMESTAMP_NTZ and TIMESTAMP_LTZ", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773"))
+ enableSuite[RapidsMathFunctionsSuite]
+ enableSuite[RapidsRegexpExpressionsSuite]
+ .exclude("RegexReplace", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10774"))
+ .exclude("RegexExtract", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10774"))
+ .exclude("RegexExtractAll", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10774"))
+ .exclude("SPLIT", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10774"))
+ enableSuite[RapidsStringExpressionsSuite]
+ .exclude("SPARK-22498: Concat should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("SPARK-22549: ConcatWs should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("SPARK-22550: Elt should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("StringComparison", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("Substring", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("ascii for string", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("base64/unbase64 for string", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("encode/decode for string", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("LOCATE", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("LPAD/RPAD", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("REPEAT", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("length for string / binary", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ .exclude("ParseUrl", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
+ enableSuite[RapidsStringFunctionsSuite]
+}
+// scalastyle:on line.size.limit
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsBaseTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsBaseTrait.scala
new file mode 100644
index 00000000000..a3039077d90
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsBaseTrait.scala
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.utils
+
+trait RapidsTestsBaseTrait {
+
+ protected val rootPath: String = getClass.getResource("/").getPath
+ protected val basePath: String = rootPath + "unit-tests-working-home"
+
+ protected val warehouse: String = basePath + "/spark-warehouse"
+ protected val metaStorePathAbsolute: String = basePath + "/meta"
+
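+ // Delegates to BackendTestSettings.instance (RapidsTestSettings, loaded reflectively),
+ // so per-suite include/exclude rules apply to every suite mixing in this trait.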
+ def shouldRun(testName: String): Boolean = {
+ BackendTestSettings.shouldRun(getClass.getCanonicalName, testName)
+ }
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsCommonTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsCommonTrait.scala
new file mode 100644
index 00000000000..1b39073fdcf
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsCommonTrait.scala
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.utils
+
+import com.nvidia.spark.rapids.TestStats
+import org.scalactic.source.Position
+import org.scalatest.{Args, Status, Tag}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.config.Tests.IS_TESTING
+import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
+import org.apache.spark.sql.rapids.utils.RapidsTestConstants.RAPIDS_TEST
+
+trait RapidsTestsCommonTrait
+ extends SparkFunSuite
+ with ExpressionEvalHelper
+ with RapidsTestsBaseTrait {
+
+ protected override def afterAll(): Unit = {
+ // SparkFunSuite sets this to true but forgets to reset it to false.
+ System.clearProperty(IS_TESTING.key)
+ super.afterAll()
+ }
+
+ override def runTest(testName: String, args: Args): Status = {
+ TestStats.suiteTestNumber += 1
+ TestStats.offloadRapids = true
+ TestStats.startCase(testName)
+ val status = super.runTest(testName, args)
+ if (TestStats.offloadRapids) {
+ TestStats.offloadRapidsTestNumber += 1
+ print("'" + testName + "'" + " offload to RAPIDS\n")
+ } else {
+ // you can find the keyword 'Validation failed for' in function doValidate() in log
+ // to get the fallback reason
+ print("'" + testName + "'" + " NOT use RAPIDS\n")
+ TestStats.addFallBackCase()
+ }
+
+ TestStats.endCase(status.succeeds())
+ status
+ }
+
+ protected def testRapids(testName: String, testTag: Tag*)(testFun: => Any)(implicit
+ pos: Position): Unit = {
+ test(RAPIDS_TEST + testName, testTag: _*)(testFun)
+ }
+ override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
+ pos: Position): Unit = {
+ if (shouldRun(testName)) {
+ super.test(testName, testTags: _*)(testFun)
+ } else {
+ super.ignore(testName, testTags: _*)(testFun)
+ }
+ }
+}
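+
+// Usage sketch (illustrative): suites mixing in this trait register replacements for
+// excluded upstream tests under the RAPIDS prefix, e.g.
+//   testRapids("collect functions") { /* replacement body */ }
+// which runs as "Rapids - collect functions" and is matched by the *RapidsTest*
+// rules in BackendTestSettings.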
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsTrait.scala
new file mode 100644
index 00000000000..08e10f4d4fd
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsTrait.scala
@@ -0,0 +1,335 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.utils
+
+import java.io.File
+
+import com.nvidia.spark.rapids.{GpuProjectExec, TestStats}
+import org.apache.commons.io.{FileUtils => fu}
+import org.apache.commons.math3.util.Precision
+import org.scalactic.TripleEqualsSupport.Spread
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, ConvertToLocalRelation, NullPropagation}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.utils.RapidsQueryTestUtil.isNaNOrInf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+trait RapidsTestsTrait extends RapidsTestsCommonTrait {
+
+ override def beforeAll(): Unit = {
+ // prepare working paths
+ val basePathDir = new File(basePath)
+ if (basePathDir.exists()) {
+ fu.forceDelete(basePathDir)
+ }
+ fu.forceMkdir(basePathDir)
+ fu.forceMkdir(new File(warehouse))
+ fu.forceMkdir(new File(metaStorePathAbsolute))
+ super.beforeAll()
+ initializeSession()
+ _spark.sparkContext.setLogLevel("WARN")
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ super.afterAll()
+ } finally {
+ try {
+ if (_spark != null) {
+ try {
+ _spark.sessionState.catalog.reset()
+ } finally {
+ _spark.stop()
+ _spark = null
+ }
+ }
+ } finally {
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ }
+ }
+ logInfo(
+ "Test suite: " + this.getClass.getSimpleName +
+ "; Suite test number: " + TestStats.suiteTestNumber +
+ "; OffloadRapids number: " + TestStats.offloadRapidsTestNumber + "\n")
+ TestStats.printMarkdown(this.getClass.getSimpleName)
+ TestStats.reset()
+ }
+
+ protected def initializeSession(): Unit = {
+ if (_spark == null) {
+ val sparkBuilder = SparkSession
+ .builder()
+ .master(s"local[2]")
+      // Exclude these rules so Spark Catalyst does not statically evaluate literal inputs.
+ .config(
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
+ ConvertToLocalRelation.ruleName +
+ "," + ConstantFolding.ruleName + "," + NullPropagation.ruleName)
+ .config("spark.rapids.sql.enabled", "true")
+ .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
+ .config("spark.sql.queryExecutionListeners",
+ "org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback")
+ .config("spark.sql.warehouse.dir", warehouse)
+ .appName("rapids spark plugin running Vanilla Spark UT")
+
+ _spark = sparkBuilder
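+        // Fail the test on unsafe memory leaks instead of only logging them.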
+ .config("spark.unsafe.exceptionOnMemoryLeak", "true")
+ .getOrCreate()
+ }
+ }
+
+ protected var _spark: SparkSession = null
+
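+  // Evaluate the expression through a real DataFrame so the RAPIDS plugin can
+  // plan it on the GPU. If the input row cannot be represented as a DataFrame,
+  // fall back to Catalyst's in-process evaluation helpers.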
+ override protected def checkEvaluation(
+ expression: => Expression,
+ expected: Any,
+ inputRow: InternalRow = EmptyRow): Unit = {
+ val resolver = ResolveTimeZone
+ val expr = resolver.resolveTimeZones(expression)
+ assert(expr.resolved)
+
+ if (canConvertToDataFrame(inputRow)) {
+ rapidsCheckExpression(expr, expected, inputRow)
+ } else {
+ logWarning(s"The status of this unit test is not guaranteed.")
+ val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
+ checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
+ checkEvaluationWithMutableProjection(expr, catalystValue, inputRow)
+ if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
+ checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow)
+ }
+ checkEvaluationWithOptimization(expr, catalystValue, inputRow)
+ }
+ }
+
+  /**
+   * Sorts map data by key and returns the sorted key array and value array.
+   *
+   * @param input the input map data.
+   * @param kt the key type.
+   * @param vt the value type.
+   * @return the sorted key array and value array.
+   */
+ private def getSortedArrays(
+ input: MapData,
+ kt: DataType,
+ vt: DataType): (ArrayData, ArrayData) = {
+ val keyArray = input.keyArray().toArray[Any](kt)
+ val valueArray = input.valueArray().toArray[Any](vt)
+    val newMap = keyArray.zip(valueArray).toMap
+ val sortedMap = mutable.SortedMap(newMap.toSeq: _*)(TypeUtils.getInterpretedOrdering(kt))
+ (new GenericArrayData(sortedMap.keys.toArray), new GenericArrayData(sortedMap.values.toArray))
+ }
+
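+  // Overrides ExpressionEvalHelper's result check: maps are compared after
+  // sorting by key, and floating-point values are compared with a tolerance.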
+ override protected def checkResult(
+ result: Any,
+ expected: Any,
+ exprDataType: DataType,
+ exprNullable: Boolean): Boolean = {
+ val dataType = UserDefinedType.sqlType(exprDataType)
+
+ // The result is null for a non-nullable expression
+ assert(result != null || exprNullable, "exprNullable should be true if result is null")
+ (result, expected) match {
+ case (result: Array[Byte], expected: Array[Byte]) =>
+ java.util.Arrays.equals(result, expected)
+      case (result: Double, expected: Spread[Double @unchecked]) =>
+        expected.isWithin(result)
+ case (result: InternalRow, expected: InternalRow) =>
+ val st = dataType.asInstanceOf[StructType]
+ assert(result.numFields == st.length && expected.numFields == st.length)
+ st.zipWithIndex.forall {
+ case (f, i) =>
+ checkResult(
+ result.get(i, f.dataType),
+ expected.get(i, f.dataType),
+ f.dataType,
+ f.nullable)
+ }
+ case (result: ArrayData, expected: ArrayData) =>
+ result.numElements == expected.numElements && {
+ val ArrayType(et, cn) = dataType.asInstanceOf[ArrayType]
+ var isSame = true
+ var i = 0
+ while (isSame && i < result.numElements) {
+ isSame = checkResult(result.get(i, et), expected.get(i, et), et, cn)
+ i += 1
+ }
+ isSame
+ }
+ case (result: MapData, expected: MapData) =>
+ val MapType(kt, vt, vcn) = dataType.asInstanceOf[MapType]
+ checkResult(
+ getSortedArrays(result, kt, vt)._1,
+ getSortedArrays(expected, kt, vt)._1,
+ ArrayType(kt, containsNull = false),
+ exprNullable = false) && checkResult(
+ getSortedArrays(result, kt, vt)._2,
+ getSortedArrays(expected, kt, vt)._2,
+ ArrayType(vt, vcn),
+ exprNullable = false)
+ case (result: Double, expected: Double) =>
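+        // NaN, infinities and signed zeros must match bit-for-bit; all other
+        // doubles are compared with a relative tolerance.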
+ if (
+ (isNaNOrInf(result) || isNaNOrInf(expected))
+ || (result == -0.0) || (expected == -0.0)
+ ) {
+ java.lang.Double.doubleToRawLongBits(result) ==
+ java.lang.Double.doubleToRawLongBits(expected)
+ } else {
+ Precision.equalsWithRelativeTolerance(result, expected, 0.00001d)
+ }
+ case (result: Float, expected: Float) =>
+ if (expected.isNaN) result.isNaN else expected == result
+ case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
+ case _ =>
+ result == expected
+ }
+ }
+
+ def checkDataTypeSupported(expr: Expression): Boolean = {
+ RapidsTestConstants.SUPPORTED_DATA_TYPE.acceptsType(expr.dataType)
+ }
+
+ def rapidsCheckExpression(expression: Expression, expected: Any, inputRow: InternalRow): Unit = {
+ val df = if (inputRow != EmptyRow && inputRow != InternalRow.empty) {
+ convertInternalRowToDataFrame(inputRow)
+ } else {
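+      // No usable input row: evaluate the expression against a dummy one-row DataFrame.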
+ val schema = StructType(StructField("a", IntegerType, nullable = true) :: Nil)
+ val empData = Seq(Row(1))
+ _spark.createDataFrame(_spark.sparkContext.parallelize(empData), schema)
+ }
+ val resultDF = df.select(Column(expression))
+ val result = resultDF.collect()
+ TestStats.testUnitNumber = TestStats.testUnitNumber + 1
+ if (
+ checkDataTypeSupported(expression) &&
+ expression.children.forall(checkDataTypeSupported)
+ ) {
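+      // Exactly one GpuProjectExec in the executed plan means the projection
+      // ran on the GPU; otherwise the expression fell back to vanilla Spark.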
+ val projectTransformer = resultDF.queryExecution.executedPlan.collect {
+ case p: GpuProjectExec => p
+ }
+ if (projectTransformer.size == 1) {
+ TestStats.offloadRapidsUnitNumber += 1
+ logInfo("Offload to native backend in the test.\n")
+ } else {
+ logInfo("Not supported in native backend, fall back to vanilla spark in the test.\n")
+ shouldNotFallback()
+ }
+ } else {
+ logInfo("Has unsupported data type, fall back to vanilla spark.\n")
+ shouldNotFallback()
+ }
+
+ if (
+ !(checkResult(result.head.get(0), expected, expression.dataType, expression.nullable)
+ || checkResult(
+ CatalystTypeConverters.createToCatalystConverter(expression.dataType)(
+ result.head.get(0)
+        ), // the decimal precision inferred from the value may be wrong
+ CatalystTypeConverters.convertToCatalyst(expected),
+ expression.dataType,
+ expression.nullable
+ ))
+ ) {
+ val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(
+ s"Incorrect evaluation: $expression, " +
+ s"actual: ${result.head.get(0)}, " +
+ s"expected: $expected$input")
+ }
+ }
+
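+  // Records that the current test case fell back to vanilla Spark instead of
+  // being offloaded to the RAPIDS backend.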
+ def shouldNotFallback(): Unit = {
+ TestStats.offloadRapids = false
+ }
+
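+  // convertInternalRowToDataFrame only handles flat rows of primitive-like
+  // values, so rows containing nested map, array or struct data cannot be used.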
+ def canConvertToDataFrame(inputRow: InternalRow): Boolean = {
+ if (inputRow == EmptyRow || inputRow == InternalRow.empty) {
+ return true
+ }
+ if (!inputRow.isInstanceOf[GenericInternalRow]) {
+ return false
+ }
+ val values = inputRow.asInstanceOf[GenericInternalRow].values
+ for (value <- values) {
+ value match {
+ case _: MapData => return false
+ case _: ArrayData => return false
+ case _: InternalRow => return false
+ case _ =>
+ }
+ }
+ true
+ }
+
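+  // Builds a one-row DataFrame from a flat GenericInternalRow, inferring a
+  // field name and Spark SQL type from each value's runtime class.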
+ def convertInternalRowToDataFrame(inputRow: InternalRow): DataFrame = {
+    val structFieldSeq = new ArrayBuffer[StructField]()
+    val values = inputRow match {
+      case genericInternalRow: GenericInternalRow =>
+        genericInternalRow.values
+      case _ => throw new UnsupportedOperationException("Unsupported InternalRow.")
+    }
+    values.foreach {
+      case boolean: java.lang.Boolean =>
+        structFieldSeq.append(StructField("bool", BooleanType, boolean == null))
+      case short: java.lang.Short =>
+        structFieldSeq.append(StructField("i16", ShortType, short == null))
+      case byte: java.lang.Byte =>
+        structFieldSeq.append(StructField("i8", ByteType, byte == null))
+      case integer: java.lang.Integer =>
+        structFieldSeq.append(StructField("i32", IntegerType, integer == null))
+      case long: java.lang.Long =>
+        structFieldSeq.append(StructField("i64", LongType, long == null))
+      case float: java.lang.Float =>
+        structFieldSeq.append(StructField("fp32", FloatType, float == null))
+      case double: java.lang.Double =>
+        structFieldSeq.append(StructField("fp64", DoubleType, double == null))
+      case utf8String: UTF8String =>
+        structFieldSeq.append(StructField("str", StringType, utf8String == null))
+      case byteArr: Array[Byte] =>
+        structFieldSeq.append(StructField("vbin", BinaryType, byteArr == null))
+      case decimal: Decimal =>
+        structFieldSeq.append(
+          StructField("dec", DecimalType(decimal.precision, decimal.scale), decimal == null))
+      case null =>
+        structFieldSeq.append(StructField("null", IntegerType, nullable = true))
+      case unsupported =>
+        throw new UnsupportedOperationException(s"Unsupported type: ${unsupported.getClass}")
+    }
+    val fields = structFieldSeq.toSeq
+ _spark.internalCreateDataFrame(
+ _spark.sparkContext.parallelize(Seq(inputRow)),
+ StructType(fields))
+ }
+}