diff --git a/NOTICE b/NOTICE index eb8aaeb9977..657bfecd06a 100644 --- a/NOTICE +++ b/NOTICE @@ -48,6 +48,17 @@ The Apache Software Foundation (http://www.apache.org/). -------------------------------------------------------------------------------- +This project includes software from the Apache Gluten project +(www.github.com/apache/incubator-gluten/). + +Apache Gluten (Incubating) +Copyright (2024) The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +-------------------------------------------------------------------------------- + This project includes code from Kite, developed at Cloudera, Inc. with the following copyright notice: diff --git a/pom.xml b/pom.xml index b45eca7abeb..c939f8d5891 100644 --- a/pom.xml +++ b/pom.xml @@ -1039,6 +1039,53 @@ ${mockito.version} test + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + + + org.apache.hadoop + hadoop-client-api + + + org.apache.hadoop + hadoop-client-runtime + + + org.apache.curator + curator-recipes + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + test-jar + test + diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml index c60e0b19034..cf3bfb48373 100644 --- a/scala2.13/pom.xml +++ b/scala2.13/pom.xml @@ -1039,6 +1039,53 @@ ${mockito.version} test + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + + + org.apache.hadoop + hadoop-client-api + + + org.apache.hadoop + hadoop-client-runtime + + + org.apache.curator + curator-recipes + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + test-jar + test + diff --git a/scala2.13/tests/pom.xml b/scala2.13/tests/pom.xml index d90842c7955..a9c1d707b27 100644 --- a/scala2.13/tests/pom.xml +++ b/scala2.13/tests/pom.xml @@ -103,6 +103,27 @@ org.apache.spark spark-avro_${scala.binary.version} + + org.apache.spark + spark-core_${scala.binary.version} + test-jar + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + + + org.scalatestplus + scalatestplus-scalacheck_${scala.binary.version} + 3.1.0.0-RC2 + test + diff --git a/tests/pom.xml b/tests/pom.xml index 71143614f7b..96c2a051143 100644 --- a/tests/pom.xml +++ b/tests/pom.xml @@ -103,6 +103,27 @@ org.apache.spark spark-avro_${scala.binary.version} + + org.apache.spark + spark-core_${scala.binary.version} + test-jar + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + + + org.scalatestplus + scalatestplus-scalacheck_${scala.binary.version} + 3.1.0.0-RC2 + test + diff --git a/tests/src/test/java/com/nvidia/spark/rapids/TestStats.java b/tests/src/test/java/com/nvidia/spark/rapids/TestStats.java new file mode 100644 index 00000000000..3f367529bcd --- /dev/null +++ b/tests/src/test/java/com/nvidia/spark/rapids/TestStats.java @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.Stack; + +/** Only use in UT Env. It's not thread safe. */ +public class TestStats { + private static final String HEADER_FORMAT = "%s%s"; + private static final String ROW_FORMAT = + "%s%s%s%s%s%s"; + + private static boolean UT_ENV = false; + private static final Map caseInfos = new HashMap<>(); + private static String currentCase; + public static int offloadRapidsUnitNumber = 0; + public static int testUnitNumber = 0; + + // use the rapids backend to execute the query + public static boolean offloadRapids = true; + public static int suiteTestNumber = 0; + public static int offloadRapidsTestNumber = 0; + + public static void beginStatistic() { + UT_ENV = true; + } + + public static void reset() { + offloadRapids = false; + suiteTestNumber = 0; + offloadRapidsTestNumber = 0; + testUnitNumber = 0; + offloadRapidsUnitNumber = 0; + resetCase(); + caseInfos.clear(); + } + + private static int totalSuiteTestNumber = 0; + public static int totalOffloadRapidsTestNumber = 0; + + public static int totalTestUnitNumber = 0; + public static int totalOffloadRapidsCaseNumber = 0; + + public static void printMarkdown(String suitName) { + if (!UT_ENV) { + return; + } + + String title = "print_markdown_" + suitName; + + String info = + "Case Count: %d, OffloadRapids Case Count: %d, " + + "Unit Count %d, OffloadRapids Unit Count %d"; + + System.out.println( + String.format( + HEADER_FORMAT, + title, + String.format( + info, + TestStats.suiteTestNumber, + TestStats.offloadRapidsTestNumber, + TestStats.testUnitNumber, + TestStats.offloadRapidsUnitNumber))); + + caseInfos.forEach( + (key, value) -> + System.out.println( + String.format( + ROW_FORMAT, + title, + key, + value.status, + value.type, + String.join("
", value.fallbackExpressionName), + String.join("
", value.fallbackClassName)))); + + totalSuiteTestNumber += suiteTestNumber; + totalOffloadRapidsTestNumber += offloadRapidsTestNumber; + totalTestUnitNumber += testUnitNumber; + totalOffloadRapidsCaseNumber += offloadRapidsUnitNumber; + System.out.println( + "total_markdown_ totalCaseNum:" + + totalSuiteTestNumber + + " offloadRapids: " + + totalOffloadRapidsTestNumber + + " total unit: " + + totalTestUnitNumber + + " offload unit: " + + totalOffloadRapidsCaseNumber); + } + + public static void addFallBackClassName(String className) { + if (!UT_ENV) { + return; + } + + if (caseInfos.containsKey(currentCase) && !caseInfos.get(currentCase).stack.isEmpty()) { + CaseInfo info = caseInfos.get(currentCase); + caseInfos.get(currentCase).fallbackExpressionName.add(info.stack.pop()); + caseInfos.get(currentCase).fallbackClassName.add(className); + } + } + + public static void addFallBackCase() { + if (!UT_ENV) { + return; + } + + if (caseInfos.containsKey(currentCase)) { + caseInfos.get(currentCase).type = "fallback"; + } + } + + public static void addExpressionClassName(String className) { + if (!UT_ENV) { + return; + } + + if (caseInfos.containsKey(currentCase)) { + CaseInfo info = caseInfos.get(currentCase); + info.stack.add(className); + } + } + + public static Set getFallBackClassName() { + if (!UT_ENV) { + return Collections.emptySet(); + } + + if (caseInfos.containsKey(currentCase)) { + return Collections.unmodifiableSet(caseInfos.get(currentCase).fallbackExpressionName); + } + + return Collections.emptySet(); + } + + public static void addIgnoreCaseName(String caseName) { + if (!UT_ENV) { + return; + } + + if (caseInfos.containsKey(caseName)) { + caseInfos.get(caseName).type = "fatal"; + } + } + + public static void resetCase() { + if (!UT_ENV) { + return; + } + + if (caseInfos.containsKey(currentCase)) { + caseInfos.get(currentCase).stack.clear(); + } + currentCase = ""; + } + + public static void startCase(String caseName) { + if (!UT_ENV) { + return; + } + + caseInfos.putIfAbsent(caseName, new CaseInfo()); + currentCase = caseName; + } + + public static void endCase(boolean status) { + if (!UT_ENV) { + return; + } + + if (caseInfos.containsKey(currentCase)) { + caseInfos.get(currentCase).status = status ? "success" : "error"; + } + + resetCase(); + } +} + +class CaseInfo { + final Stack stack = new Stack<>(); + Set fallbackExpressionName = new HashSet<>(); + Set fallbackClassName = new HashSet<>(); + String type = ""; + String status = ""; +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsCastSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsCastSuite.scala new file mode 100644 index 00000000000..f3fec27f7f6 --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsCastSuite.scala @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.suites + +import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, CastSuite, Expression, Literal} +import org.apache.spark.sql.rapids.utils.RapidsTestsTrait +import org.apache.spark.sql.types._ + +class RapidsCastSuite extends CastSuite with RapidsTestsTrait { + // example to enhance logging for base suite + override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = { + v match { + case lit: Expression => + logDebug(s"Cast from: ${lit.dataType.typeName}, to: ${targetType.typeName}") + Cast(lit, targetType, timeZoneId) + case _ => + val lit = Literal(v) + logDebug(s"Cast from: ${lit.dataType.typeName}, to: ${targetType.typeName}") + Cast(lit, targetType, timeZoneId) + } + } +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsDataFrameAggregateSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsDataFrameAggregateSuite.scala new file mode 100644 index 00000000000..5a394a5b0e8 --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsDataFrameAggregateSuite.scala @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.suites + +import org.apache.spark.sql.DataFrameAggregateSuite +import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait + +class RapidsDataFrameAggregateSuite extends DataFrameAggregateSuite with RapidsSQLTestsTrait { + // example to show how to replace the logic of an excluded test case in Vanilla Spark + testRapids("collect functions" ) { // "collect functions" was excluded at RapidsTestSettings + // println("...") + } +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonFunctionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonFunctionsSuite.scala new file mode 100644 index 00000000000..43150c0df4b --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonFunctionsSuite.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.suites + +import org.apache.spark.sql.JsonFunctionsSuite +import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait + +class RapidsJsonFunctionsSuite extends JsonFunctionsSuite with RapidsSQLTestsTrait {} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonSuite.scala new file mode 100644 index 00000000000..6d244c67ad0 --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsJsonSuite.scala @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.suites + +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.datasources.{InMemoryFileIndex, NoopCache} +import org.apache.spark.sql.execution.datasources.json.JsonSuite +import org.apache.spark.sql.execution.datasources.v2.json.JsonScanBuilder +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.utils.RapidsSQLTestsBaseTrait +import org.apache.spark.sql.sources +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class RapidsJsonSuite extends JsonSuite with RapidsSQLTestsBaseTrait { + + /** Returns full path to the given file in the resource folder */ + override protected def testFile(fileName: String): String = { + getWorkspaceFilePath("sql", "core", "src", "test", "resources").toString + "/" + fileName + } +} + +class RapidsJsonV1Suite extends RapidsJsonSuite with RapidsSQLTestsBaseTrait { + override def sparkConf: SparkConf = + super.sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "json") +} + +class RapidsJsonV2Suite extends RapidsJsonSuite with RapidsSQLTestsBaseTrait { + override def sparkConf: SparkConf = + super.sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "") + + test("get pushed filters") { + val attr = "col" + def getBuilder(path: String): JsonScanBuilder = { + val fileIndex = new InMemoryFileIndex( + spark, + Seq(new org.apache.hadoop.fs.Path(path, "file.json")), + Map.empty, + None, + NoopCache) + val schema = new StructType().add(attr, IntegerType) + val options = CaseInsensitiveStringMap.empty() + new JsonScanBuilder(spark, fileIndex, schema, schema, options) + } + val filters: Array[sources.Filter] = Array(sources.IsNotNull(attr)) + withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { + file => + val scanBuilder = getBuilder(file.getCanonicalPath) + assert(scanBuilder.pushDataFilters(filters) === filters) + } + } + + withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "false") { + withTempPath { + file => + val scanBuilder = getBuilder(file.getCanonicalPath) + 
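// With JSON_FILTER_PUSHDOWN_ENABLED set to "false" above, no data filters are pushed, so pushDataFilters is expected to return an empty array: +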
assert(scanBuilder.pushDataFilters(filters) === Array.empty[sources.Filter]) + } + } + } +} + +class RapidsJsonLegacyTimeParserSuite extends RapidsJsonSuite with RapidsSQLTestsBaseTrait { + override def sparkConf: SparkConf = + super.sparkConf + .set(SQLConf.LEGACY_TIME_PARSER_POLICY, "legacy") +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsMathFunctionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsMathFunctionsSuite.scala new file mode 100644 index 00000000000..55b4b00f680 --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsMathFunctionsSuite.scala @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.suites + +import org.apache.spark.sql.MathFunctionsSuite +import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait + +class RapidsMathFunctionsSuite extends MathFunctionsSuite with RapidsSQLTestsTrait { +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsRegexpExpressionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsRegexpExpressionsSuite.scala new file mode 100644 index 00000000000..95b54240dbe --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsRegexpExpressionsSuite.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.suites + +import org.apache.spark.sql.catalyst.expressions.RegexpExpressionsSuite +import org.apache.spark.sql.rapids.utils.RapidsTestsTrait + +class RapidsRegexpExpressionsSuite extends RegexpExpressionsSuite with RapidsTestsTrait {} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringExpressionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringExpressionsSuite.scala new file mode 100644 index 00000000000..164406fdf83 --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringExpressionsSuite.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.suites + +import org.apache.spark.sql.catalyst.expressions.StringExpressionsSuite +import org.apache.spark.sql.rapids.utils.RapidsTestsTrait + +class RapidsStringExpressionsSuite extends StringExpressionsSuite with RapidsTestsTrait {} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringFunctionsSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringFunctionsSuite.scala new file mode 100644 index 00000000000..7b4a8ac6d7d --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsStringFunctionsSuite.scala @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.suites + +import org.apache.spark.sql.StringFunctionsSuite +import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait + +class RapidsStringFunctionsSuite + extends StringFunctionsSuite + with RapidsSQLTestsTrait { +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/BackendTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/BackendTestSettings.scala new file mode 100644 index 00000000000..83396e977fa --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/BackendTestSettings.scala @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.utils + +import java.util + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import com.nvidia.spark.rapids.TestStats + +import org.apache.spark.sql.rapids.utils.RapidsTestConstants.RAPIDS_TEST + +abstract class BackendTestSettings { + + private val enabledSuites: java.util.Map[String, SuiteSettings] = new util.HashMap() + + protected def enableSuite[T: ClassTag]: SuiteSettings = { + val suiteName = implicitly[ClassTag[T]].runtimeClass.getCanonicalName + if (enabledSuites.containsKey(suiteName)) { + throw new IllegalArgumentException("Duplicated suite name: " + suiteName) + } + val suiteSettings = new SuiteSettings + enabledSuites.put(suiteName, suiteSettings) + suiteSettings + } + + private[utils] def shouldRun(suiteName: String, testName: String): Boolean = { + if (!enabledSuites.containsKey(suiteName)) { + return false + } + + val suiteSettings = enabledSuites.get(suiteName) + + val inclusion = suiteSettings.inclusion.asScala + val exclusion = suiteSettings.exclusion.asScala + + if (inclusion.isEmpty && exclusion.isEmpty) { + // default to run all cases under this suite + return true + } + + if (inclusion.nonEmpty && exclusion.nonEmpty) { + // error + throw new IllegalStateException( + s"Do not use include and exclude conditions on the same test case: $suiteName:$testName") + } + + if (inclusion.nonEmpty) { + // include mode + val isIncluded = inclusion.exists(_.isIncluded(testName)) + return isIncluded + } + + if (exclusion.nonEmpty) { + // exclude mode + val isExcluded = exclusion.exists(_.isExcluded(testName)) + return !isExcluded + } + + throw new IllegalStateException("Unreachable code") + } + + sealed trait ExcludeReason + // The reason should most likely to be a issue link, + // or a description like "This simply can't work on GPU". 
+ // It should never be "unknown" or "need investigation" + case class KNOWN_ISSUE(reason: String) extends ExcludeReason + case class WONT_FIX_ISSUE(reason: String) extends ExcludeReason + + + final protected class SuiteSettings { + private[utils] val inclusion: util.List[IncludeBase] = new util.ArrayList() + private[utils] val exclusion: util.List[ExcludeBase] = new util.ArrayList() + private[utils] val excludeReasons: util.List[ExcludeReason] = new util.ArrayList() + + def include(testNames: String*): SuiteSettings = { + inclusion.add(Include(testNames: _*)) + this + } + def exclude(testNames: String, reason: ExcludeReason): SuiteSettings = { + exclusion.add(Exclude(testNames)) + excludeReasons.add(reason) + this + } + def includeRapidsTest(testName: String*): SuiteSettings = { + inclusion.add(IncludeRapidsTest(testName: _*)) + this + } + def excludeRapidsTest(testName: String, reason: ExcludeReason): SuiteSettings = { + exclusion.add(ExcludeRapidsTest(testName)) + excludeReasons.add(reason) + this + } + def includeByPrefix(prefixes: String*): SuiteSettings = { + inclusion.add(IncludeByPrefix(prefixes: _*)) + this + } + def excludeByPrefix(prefixes: String, reason: ExcludeReason): SuiteSettings = { + exclusion.add(ExcludeByPrefix(prefixes)) + excludeReasons.add(reason) + this + } + def includeRapidsTestsByPrefix(prefixes: String*): SuiteSettings = { + inclusion.add(IncludeRapidsTestByPrefix(prefixes: _*)) + this + } + def excludeRapidsTestsByPrefix(prefixes: String, reason: ExcludeReason): SuiteSettings = { + exclusion.add(ExcludeRadpisTestByPrefix(prefixes)) + excludeReasons.add(reason) + this + } + def includeAllRapidsTests(): SuiteSettings = { + inclusion.add(IncludeByPrefix(RAPIDS_TEST)) + this + } + def excludeAllRapidsTests(reason: ExcludeReason): SuiteSettings = { + exclusion.add(ExcludeByPrefix(RAPIDS_TEST)) + excludeReasons.add(reason) + this + } + } + + protected trait IncludeBase { + def isIncluded(testName: String): Boolean + } + protected trait ExcludeBase { + def isExcluded(testName: String): Boolean + } + private case class Include(testNames: String*) extends IncludeBase { + val nameSet: Set[String] = Set(testNames: _*) + override def isIncluded(testName: String): Boolean = nameSet.contains(testName) + } + private case class Exclude(testNames: String*) extends ExcludeBase { + val nameSet: Set[String] = Set(testNames: _*) + override def isExcluded(testName: String): Boolean = nameSet.contains(testName) + } + private case class IncludeRapidsTest(testNames: String*) extends IncludeBase { + val nameSet: Set[String] = testNames.map(name => RAPIDS_TEST + name).toSet + override def isIncluded(testName: String): Boolean = nameSet.contains(testName) + } + private case class ExcludeRapidsTest(testNames: String*) extends ExcludeBase { + val nameSet: Set[String] = testNames.map(name => RAPIDS_TEST + name).toSet + override def isExcluded(testName: String): Boolean = nameSet.contains(testName) + } + private case class IncludeByPrefix(prefixes: String*) extends IncludeBase { + override def isIncluded(testName: String): Boolean = { + if (prefixes.exists(prefix => testName.startsWith(prefix))) { + return true + } + false + } + } + private case class ExcludeByPrefix(prefixes: String*) extends ExcludeBase { + override def isExcluded(testName: String): Boolean = { + if (prefixes.exists(prefix => testName.startsWith(prefix))) { + return true + } + false + } + } + private case class IncludeRapidsTestByPrefix(prefixes: String*) extends IncludeBase { + override def isIncluded(testName: String): 
Boolean = { + if (prefixes.exists(prefix => testName.startsWith(RAPIDS_TEST + prefix))) { + return true + } + false + } + } + private case class ExcludeRadpisTestByPrefix(prefixes: String*) extends ExcludeBase { + override def isExcluded(testName: String): Boolean = { + if (prefixes.exists(prefix => testName.startsWith(RAPIDS_TEST + prefix))) { + return true + } + false + } + } +} + +object BackendTestSettings { + val instance: BackendTestSettings = { + Class + .forName("org.apache.spark.sql.rapids.utils.RapidsTestSettings") + .getDeclaredConstructor() + .newInstance() + .asInstanceOf[BackendTestSettings] + } + + def shouldRun(suiteName: String, testName: String): Boolean = { + val v = instance.shouldRun(suiteName, testName: String) + + if (!v) { + TestStats.addIgnoreCaseName(testName) + } + + v + } +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsBaseTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsBaseTrait.scala new file mode 100644 index 00000000000..540c70a2ee1 --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsBaseTrait.scala @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.utils + +import java.util.{Locale, TimeZone} + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.SparkConf +import org.apache.spark.internal.config.Tests.IS_TESTING +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, ShuffleQueryStageExec} +import org.apache.spark.sql.rapids.utils.RapidsTestConstants.RAPIDS_TEST +import org.apache.spark.sql.test.SharedSparkSession + + +/** Basic trait for Rapids SQL test cases. */ +trait RapidsSQLTestsBaseTrait extends SharedSparkSession with RapidsTestsBaseTrait { + + protected override def afterAll(): Unit = { + // SparkFunSuite will set this to true, and forget to reset to false + System.clearProperty(IS_TESTING.key) + super.afterAll() + } + + protected def testRapids(testName: String, testTag: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + test(RAPIDS_TEST + testName, testTag: _*)(testFun) + } + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + if (shouldRun(testName)) { + super.test(testName, testTags: _*)(testFun) + } else { + super.ignore(testName, testTags: _*)(testFun) + } + } + + override def sparkConf: SparkConf = { + RapidsSQLTestsBaseTrait.nativeSparkConf(super.sparkConf, warehouse) + } + + /** + * Get all the children plan of plans. + * + * @param plans + * : the input plans. 
+ * @return + */ + private def getChildrenPlan(plans: Seq[SparkPlan]): Seq[SparkPlan] = { + if (plans.isEmpty) { + return Seq() + } + + val inputPlans: Seq[SparkPlan] = plans.map { + case stage: ShuffleQueryStageExec => stage.plan + case plan => plan + } + + var newChildren: Seq[SparkPlan] = Seq() + inputPlans.foreach { + plan => + newChildren = newChildren ++ getChildrenPlan(plan.children) + // To avoid duplication of WholeStageCodegenXXX and its children. + if (!plan.nodeName.startsWith("WholeStageCodegen")) { + newChildren = newChildren :+ plan + } + } + newChildren + } + + /** + * Get the executed plan of a data frame. + * + * @param df + * : dataframe. + * @return + * A sequence of executed plans. + */ + def getExecutedPlan(df: DataFrame): Seq[SparkPlan] = { + df.queryExecution.executedPlan match { + case exec: AdaptiveSparkPlanExec => + getChildrenPlan(Seq(exec.executedPlan)) + case plan => + getChildrenPlan(Seq(plan)) + } + } +} + +object RapidsSQLTestsBaseTrait { + def nativeSparkConf(origin: SparkConf, warehouse: String): SparkConf = { + // Timezone is fixed to UTC to allow timestamps to work by default + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + // Add Locale setting + Locale.setDefault(Locale.US) + + val conf = origin + .set("spark.rapids.sql.enabled", "true") + .set("spark.plugins", "com.nvidia.spark.SQLPlugin") + .set("spark.sql.queryExecutionListeners", + "org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback") + .set("spark.sql.warehouse.dir", warehouse) + .set("spark.sql.cache.serializer", "com.nvidia.spark.ParquetCachedBatchSerializer") + .setAppName("rapids spark plugin running Vanilla Spark UT") + + conf + } +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsTrait.scala new file mode 100644 index 00000000000..4358e29630c --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsSQLTestsTrait.scala @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.utils + +import java.io.File +import java.util.TimeZone + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.{FileUtils => fu} +import org.apache.commons.math3.util.Precision +import org.scalatest.Assertions + +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.util.{sideBySide, stackTraceToString} +import org.apache.spark.sql.execution.SQLExecution + + +/** Basic trait for Rapids SQL test cases. 
*/ +trait RapidsSQLTestsTrait extends QueryTest with RapidsSQLTestsBaseTrait { + + def prepareWorkDir(): Unit = { + // prepare working paths + val basePathDir = new File(basePath) + if (basePathDir.exists()) { + fu.forceDelete(basePathDir) + } + fu.forceMkdir(basePathDir) + fu.forceMkdir(new File(warehouse)) + fu.forceMkdir(new File(metaStorePathAbsolute)) + } + + override def beforeAll(): Unit = { + prepareWorkDir() + super.beforeAll() + spark.sparkContext.setLogLevel("WARN") + } + + override def afterAll(): Unit = { + super.afterAll() + } + + override protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + val analyzedDF = + try df + catch { + case ae: AnalysisException => + val plan = ae.plan + if (plan.isDefined) { + fail(s""" + |Failed to analyze query: $ae + |${plan.get} + | + |${stackTraceToString(ae)} + |""".stripMargin) + } else { + throw ae + } + } + + assertEmptyMissingInput(analyzedDF) + + RapidsQueryTestUtil.checkAnswer(analyzedDF, expectedAnswer) + } +} + +object RapidsQueryTestUtil extends Assertions { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df + * the DataFrame to be executed + * @param expectedAnswer + * the expected result in a Seq of Rows. + * @param checkToRDD + * whether to verify deserialization to an RDD. This runs the query twice. + */ + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer, checkToRDD) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. If there was exception + * during the execution or the contents of the DataFrame does not match the expected result, an + * error message will be returned. Otherwise, a None will be returned. + * + * @param df + * the DataFrame to be executed + * @param expectedAnswer + * the expected result in a Seq of Rows. + * @param checkToRDD + * whether to verify deserialization to an RDD. This runs the query twice. + */ + def getErrorMessageInCheckAnswer( + df: DataFrame, + expectedAnswer: Seq[Row], + checkToRDD: Boolean = true): Option[String] = { + val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + if (checkToRDD) { + SQLExecution.withSQLConfPropagated(df.sparkSession) { + df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] + } + } + + val sparkAnswer = + try df.collect().toSeq + catch { + case e: Exception => + val errorMessage = + s""" + |Exception thrown while executing query: + |${df.queryExecution} + |== Exception == + |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + sameRows(expectedAnswer, sparkAnswer, isSorted).map { + results => + s""" + |Results do not match for query: + |Timezone: ${TimeZone.getDefault} + |Timezone Env: ${sys.env.getOrElse("TZ", "")} + | + |${df.queryExecution} + |== Results == + |$results + """.stripMargin + } + } + + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. 
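+ // For example, Row(new java.math.BigDecimal("1.0"), Array[Byte](1, 2)) is normalized to + // Row(BigDecimal("1.0"), Seq(1, 2)), so Scala equality compares contents rather than references.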
+ val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case bd: java.math.BigDecimal => BigDecimal(bd) + // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ + case seq: Seq[_] => + seq.map { + case b: java.lang.Byte => b.byteValue + case s: java.lang.Short => s.shortValue + case i: java.lang.Integer => i.intValue + case l: java.lang.Long => l.longValue + case f: java.lang.Float => f.floatValue + case d: java.lang.Double => d.doubleValue + case x => x + } + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + + private def genError( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean): String = { + val getRowType: Option[Row] => String = row => + row + .map( + row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }) + .getOrElse("struct<>") + + s""" + |== Results == + |${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + getRowType(expectedAnswer.headOption) +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== RAPIDS Answer - ${sparkAnswer.size} ==" +: + getRowType(sparkAnswer.headOption) +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString()) + ).mkString("\n")} + """.stripMargin + } + + def includesRows(expectedRows: Seq[Row], sparkAnswer: Seq[Row]): Option[String] = { + if (!prepareAnswer(expectedRows, true).toSet.subsetOf(prepareAnswer(sparkAnswer, true).toSet)) { + return Some(genError(expectedRows, sparkAnswer, true)) + } + None + } + + private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + case (null, null) => true + case (null, _) => false + case (_, null) => false + case (a: Array[_], b: Array[_]) => + a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r) } + case (a: Map[_, _], b: Map[_, _]) => + a.size == b.size && a.keys.forall { + aKey => b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey))) + } + case (a: Iterable[_], b: Iterable[_]) => + a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r) } + case (a: Product, b: Product) => + compare(a.productIterator.toSeq, b.productIterator.toSeq) + case (a: Row, b: Row) => + compare(a.toSeq, b.toSeq) + // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. 
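+ // For signed zeros, NaN and the infinities the raw bit patterns are compared exactly + // (doubleToRawLongBits(0.0) != doubleToRawLongBits(-0.0)); all other doubles are compared + // with a relative tolerance of 1e-5.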
+ case (a: Double, b: Double) => + if ((isNaNOrInf(a) || isNaNOrInf(b)) || (a == -0.0) || (b == -0.0)) { + java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) + } else { + Precision.equalsWithRelativeTolerance(a, b, 0.00001d) + } + case (a: Float, b: Float) => + java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) + case (a, b) => a == b + } + + def isNaNOrInf(num: Double): Boolean = { + num.isNaN || num.isInfinite || num.isNegInfinity || num.isPosInfinity + } + + def sameRows( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): Option[String] = { + // modify method 'compare' + if (!compare(prepareAnswer(expectedAnswer, isSorted), prepareAnswer(sparkAnswer, isSorted))) { + return Some(genError(expectedAnswer, sparkAnswer, isSorted)) + } + None + } + + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer.asScala.toSeq) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestConstants.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestConstants.scala new file mode 100644 index 00000000000..772becaa5f9 --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestConstants.scala @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.utils + +import org.apache.spark.sql.types._ + +object RapidsTestConstants { + + val RAPIDS_TEST: String = "Rapids - " + + val IGNORE_ALL: String = "IGNORE_ALL" + + val SUPPORTED_DATA_TYPE = TypeCollection( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType, + StringType, + BinaryType, + DateType, + TimestampType, + ArrayType, + StructType, + MapType + ) +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala new file mode 100644 index 00000000000..4981c385219 --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.utils + +import org.apache.spark.sql.rapids.suites.{RapidsCastSuite, RapidsDataFrameAggregateSuite, RapidsJsonFunctionsSuite, RapidsJsonSuite, RapidsMathFunctionsSuite, RapidsRegexpExpressionsSuite, RapidsStringExpressionsSuite, RapidsStringFunctionsSuite} + +// Some settings' line length exceeds 100 +// scalastyle:off line.size.limit + +class RapidsTestSettings extends BackendTestSettings { + + enableSuite[RapidsCastSuite] + .exclude("Process Infinity, -Infinity, NaN in case insensitive manner", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771")) + .exclude("SPARK-35711: cast timestamp without time zone to timestamp with local time zone", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771")) + .exclude("SPARK-35719: cast timestamp with local time zone to timestamp without timezone", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771")) + .exclude("SPARK-35112: Cast string to day-time interval", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771")) + .exclude("SPARK-35735: Take into account day-time interval fields in cast", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771")) + .exclude("casting to fixed-precision decimals", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771")) + .exclude("SPARK-32828: cast from a derived user-defined type to a base type", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771")) + enableSuite[RapidsDataFrameAggregateSuite] + .exclude("collect functions", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772")) + .exclude("collect functions structs", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772")) + .exclude("collect functions should be able to cast to array type with no null values", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772")) + .exclude("SPARK-17641: collect functions should not collect null values", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772")) + .exclude("SPARK-19471: AggregationIterator does not initialize the generated result projection before using it", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772")) + enableSuite[RapidsJsonFunctionsSuite] + enableSuite[RapidsJsonSuite] + .exclude("Casting long as timestamp", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("Write timestamps correctly with timestampFormat option and timeZone option", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("SPARK-23723: json in UTF-16 with BOM", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("SPARK-23723: multi-line json in UTF-32BE with BOM", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("SPARK-23723: Unsupported encoding name", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("SPARK-23723: checking that the encoding option is case agnostic", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + 
.exclude("SPARK-23723: specified encoding is not matched to actual encoding", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("SPARK-23724: lineSep should be set if encoding if different from UTF-8", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("SPARK-31716: inferring should handle malformed input", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("SPARK-24190: restrictions for JSONOptions in read", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("exception mode for parsing date/timestamp string", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + .exclude("SPARK-37360: Timestamp type inference for a mix of TIMESTAMP_NTZ and TIMESTAMP_LTZ", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10773")) + enableSuite[RapidsMathFunctionsSuite] + enableSuite[RapidsRegexpExpressionsSuite] + .exclude("RegexReplace", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10774")) + .exclude("RegexExtract", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10774")) + .exclude("RegexExtractAll", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10774")) + .exclude("SPLIT", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10774")) + enableSuite[RapidsStringExpressionsSuite] + .exclude("SPARK-22498: Concat should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("SPARK-22549: ConcatWs should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("SPARK-22550: Elt should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("StringComparison", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("Substring", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("ascii for string", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("base64/unbase64 for string", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("encode/decode for string", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("LOCATE", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("LPAD/RPAD", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("REPEAT", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("length for string / binary", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + .exclude("ParseUrl", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) + enableSuite[RapidsStringFunctionsSuite] +} +// scalastyle:on line.size.limit diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsBaseTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsBaseTrait.scala new file mode 100644 index 00000000000..a3039077d90 --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsBaseTrait.scala @@ -0,0 +1,33 @@ +/* + * Copyright (c) 
2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.utils + +trait RapidsTestsBaseTrait { + + protected val rootPath: String = getClass.getResource("/").getPath + protected val basePath: String = rootPath + "unit-tests-working-home" + + protected val warehouse: String = basePath + "/spark-warehouse" + protected val metaStorePathAbsolute: String = basePath + "/meta" + + def shouldRun(testName: String): Boolean = { + BackendTestSettings.shouldRun(getClass.getCanonicalName, testName) + } +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsCommonTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsCommonTrait.scala new file mode 100644 index 00000000000..1b39073fdcf --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsCommonTrait.scala @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.utils + +import com.nvidia.spark.rapids.TestStats +import org.scalactic.source.Position +import org.scalatest.{Args, Status, Tag} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.config.Tests.IS_TESTING +import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper +import org.apache.spark.sql.rapids.utils.RapidsTestConstants.RAPIDS_TEST + +trait RapidsTestsCommonTrait + extends SparkFunSuite + with ExpressionEvalHelper + with RapidsTestsBaseTrait { + + protected override def afterAll(): Unit = { + // SparkFunSuite will set this to true, and forget to reset to false + System.clearProperty(IS_TESTING.key) + super.afterAll() + } + + override def runTest(testName: String, args: Args): Status = { + TestStats.suiteTestNumber += 1 + TestStats.offloadRapids = true + TestStats.startCase(testName) + val status = super.runTest(testName, args) + if (TestStats.offloadRapids) { + TestStats.offloadRapidsTestNumber += 1 + print("'" + testName + "'" + " offload to RAPIDS\n") + } else { + // you can find the keyword 'Validation failed for' in function doValidate() in log + // to get the fallback reason + print("'" + testName + "'" + " NOT use RAPIDS\n") + TestStats.addFallBackCase() + } + + TestStats.endCase(status.succeeds()); + status + } + + protected def testRapids(testName: String, testTag: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + test(RAPIDS_TEST + testName, testTag: _*)(testFun) + } + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + if (shouldRun(testName)) { + super.test(testName, testTags: _*)(testFun) + } else { + super.ignore(testName, testTags: _*)(testFun) + } + } +} diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsTrait.scala new file mode 100644 index 00000000000..08e10f4d4fd --- /dev/null +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsTrait.scala @@ -0,0 +1,335 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsTrait.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsTrait.scala
new file mode 100644
index 00000000000..08e10f4d4fd
--- /dev/null
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestsTrait.scala
@@ -0,0 +1,335 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.utils
+
+import java.io.File
+
+import com.nvidia.spark.rapids.{GpuProjectExec, TestStats}
+import org.apache.commons.io.{FileUtils => fu}
+import org.apache.commons.math3.util.Precision
+import org.scalactic.TripleEqualsSupport.Spread
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, ConvertToLocalRelation, NullPropagation}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.utils.RapidsQueryTestUtil.isNaNOrInf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+trait RapidsTestsTrait extends RapidsTestsCommonTrait {
+
+  override def beforeAll(): Unit = {
+    // prepare working paths
+    val basePathDir = new File(basePath)
+    if (basePathDir.exists()) {
+      fu.forceDelete(basePathDir)
+    }
+    fu.forceMkdir(basePathDir)
+    fu.forceMkdir(new File(warehouse))
+    fu.forceMkdir(new File(metaStorePathAbsolute))
+    super.beforeAll()
+    initializeSession()
+    _spark.sparkContext.setLogLevel("WARN")
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      super.afterAll()
+    } finally {
+      try {
+        if (_spark != null) {
+          try {
+            _spark.sessionState.catalog.reset()
+          } finally {
+            _spark.stop()
+            _spark = null
+          }
+        }
+      } finally {
+        SparkSession.clearActiveSession()
+        SparkSession.clearDefaultSession()
+      }
+    }
+    logInfo(
+      "Test suite: " + this.getClass.getSimpleName +
+        "; Suite test number: " + TestStats.suiteTestNumber +
+        "; OffloadRapids number: " + TestStats.offloadRapidsTestNumber + "\n")
+    TestStats.printMarkdown(this.getClass.getSimpleName)
+    TestStats.reset()
+  }
+
+  protected def initializeSession(): Unit = {
+    if (_spark == null) {
+      val sparkBuilder = SparkSession
+        .builder()
+        .master("local[2]")
+        // Avoid static evaluation of literal inputs by Spark Catalyst: with these
+        // rules active, a literal-only expression would be folded away before it
+        // ever reaches the physical plan.
+        .config(
+          SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
+          ConvertToLocalRelation.ruleName +
+            "," + ConstantFolding.ruleName + "," + NullPropagation.ruleName)
+        .config("spark.rapids.sql.enabled", "true")
+        .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
+        .config("spark.sql.queryExecutionListeners",
+          "org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback")
+        .config("spark.sql.warehouse.dir", warehouse)
+        .appName("rapids spark plugin running Vanilla Spark UT")
+
+      _spark = sparkBuilder
+        .config("spark.unsafe.exceptionOnMemoryLeak", "true")
+        .getOrCreate()
+    }
+  }
+
+  protected var _spark: SparkSession = null
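
A sketch of what the optimizer-rule exclusions buy: with ConstantFolding active, Catalyst would evaluate upper('abc') during optimization and the plan would carry only the literal 'ABC', leaving no projection for the RAPIDS plugin to replace. The function name is hypothetical and the printed plan shape is illustrative, not captured from a run:

    import org.apache.spark.sql.SparkSession

    // Illustrative helper; 'spark' stands for any session built with the config above.
    def showLiteralProjection(spark: SparkSession): Unit = {
      val df = spark.sql("SELECT upper('abc') AS u")
      // Expected (illustrative) shape with the exclusions in place:
      //   Project [upper(abc) AS u#0]
      //   +- OneRowRelation
      println(df.queryExecution.optimizedPlan)
    }
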
+  override protected def checkEvaluation(
+      expression: => Expression,
+      expected: Any,
+      inputRow: InternalRow = EmptyRow): Unit = {
+    val resolver = ResolveTimeZone
+    val expr = resolver.resolveTimeZones(expression)
+    assert(expr.resolved)
+
+    if (canConvertToDataFrame(inputRow)) {
+      rapidsCheckExpression(expr, expected, inputRow)
+    } else {
+      logWarning("The status of this unit test is not guaranteed.")
+      val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
+      checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
+      checkEvaluationWithMutableProjection(expr, catalystValue, inputRow)
+      if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
+        checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow)
+      }
+      checkEvaluationWithOptimization(expr, catalystValue, inputRow)
+    }
+  }
+
+  /**
+   * Sort map data by key and return the sorted key array and value array.
+   *
+   * @param input input map data.
+   * @param kt key type.
+   * @param vt value type.
+   * @return the sorted key array and value array.
+   */
+  private def getSortedArrays(
+      input: MapData,
+      kt: DataType,
+      vt: DataType): (ArrayData, ArrayData) = {
+    val keyArray = input.keyArray().toArray[Any](kt)
+    val valueArray = input.valueArray().toArray[Any](vt)
+    val newMap = keyArray.zip(valueArray).toMap
+    val sortedMap = mutable.SortedMap(newMap.toSeq: _*)(TypeUtils.getInterpretedOrdering(kt))
+    (new GenericArrayData(sortedMap.keys.toArray), new GenericArrayData(sortedMap.values.toArray))
+  }
+
+  override protected def checkResult(
+      result: Any,
+      expected: Any,
+      exprDataType: DataType,
+      exprNullable: Boolean): Boolean = {
+    val dataType = UserDefinedType.sqlType(exprDataType)
+
+    // The result may only be null when the expression is nullable
+    assert(result != null || exprNullable, "exprNullable should be true if result is null")
+    (result, expected) match {
+      case (result: Array[Byte], expected: Array[Byte]) =>
+        java.util.Arrays.equals(result, expected)
+      case (result: Double, expected: Spread[Double @unchecked]) =>
+        expected.asInstanceOf[Spread[Double]].isWithin(result)
+      case (result: InternalRow, expected: InternalRow) =>
+        val st = dataType.asInstanceOf[StructType]
+        assert(result.numFields == st.length && expected.numFields == st.length)
+        st.zipWithIndex.forall {
+          case (f, i) =>
+            checkResult(
+              result.get(i, f.dataType),
+              expected.get(i, f.dataType),
+              f.dataType,
+              f.nullable)
+        }
+      case (result: ArrayData, expected: ArrayData) =>
+        result.numElements == expected.numElements && {
+          val ArrayType(et, cn) = dataType.asInstanceOf[ArrayType]
+          var isSame = true
+          var i = 0
+          while (isSame && i < result.numElements) {
+            isSame = checkResult(result.get(i, et), expected.get(i, et), et, cn)
+            i += 1
+          }
+          isSame
+        }
+      case (result: MapData, expected: MapData) =>
+        val MapType(kt, vt, vcn) = dataType.asInstanceOf[MapType]
+        checkResult(
+          getSortedArrays(result, kt, vt)._1,
+          getSortedArrays(expected, kt, vt)._1,
+          ArrayType(kt, containsNull = false),
+          exprNullable = false) && checkResult(
+          getSortedArrays(result, kt, vt)._2,
+          getSortedArrays(expected, kt, vt)._2,
+          ArrayType(vt, vcn),
+          exprNullable = false)
+      case (result: Double, expected: Double) =>
+        if (
+          (isNaNOrInf(result) || isNaNOrInf(expected))
+          || (result == -0.0) || (expected == -0.0)
+        ) {
+          java.lang.Double.doubleToRawLongBits(result) ==
+            java.lang.Double.doubleToRawLongBits(expected)
+        } else {
+          Precision.equalsWithRelativeTolerance(result, expected, 0.00001d)
+        }
+      case (result: Float, expected: Float) =>
+        if (expected.isNaN) result.isNaN else expected == result
+      case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
+      case _ =>
+        result == expected
+    }
+  }
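
The Double branch of checkResult is worth isolating: NaN, infinities, and zeros are compared bit-for-bit (so 0.0 and -0.0 are distinguished), while everything else uses a relative tolerance of 1e-5. A self-contained sketch of the same logic, with isNaNOrInf approximated by the standard-library predicates:

    import org.apache.commons.math3.util.Precision

    // Mirrors the Double case above; not the suite's own helper.
    def doublesMatch(actual: Double, expected: Double): Boolean = {
      val special = actual.isNaN || actual.isInfinite ||
        expected.isNaN || expected.isInfinite ||
        actual == -0.0 || expected == -0.0 // also true for +0.0, as in checkResult
      if (special) {
        java.lang.Double.doubleToRawLongBits(actual) ==
          java.lang.Double.doubleToRawLongBits(expected)
      } else {
        Precision.equalsWithRelativeTolerance(actual, expected, 0.00001d)
      }
    }

    // doublesMatch(1.000001, 1.0)          // true: within relative tolerance
    // doublesMatch(0.0, -0.0)              // false: raw bits differ
    // doublesMatch(Double.NaN, Double.NaN) // true: identical bit patterns
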
+  def checkDataTypeSupported(expr: Expression): Boolean = {
+    RapidsTestConstants.SUPPORTED_DATA_TYPE.acceptsType(expr.dataType)
+  }
+
+  def rapidsCheckExpression(expression: Expression, expected: Any, inputRow: InternalRow): Unit = {
+    val df = if (inputRow != EmptyRow && inputRow != InternalRow.empty) {
+      convertInternalRowToDataFrame(inputRow)
+    } else {
+      val schema = StructType(StructField("a", IntegerType, nullable = true) :: Nil)
+      val empData = Seq(Row(1))
+      _spark.createDataFrame(_spark.sparkContext.parallelize(empData), schema)
+    }
+    val resultDF = df.select(Column(expression))
+    val result = resultDF.collect()
+    TestStats.testUnitNumber = TestStats.testUnitNumber + 1
+    if (
+      checkDataTypeSupported(expression) &&
+      expression.children.forall(checkDataTypeSupported)
+    ) {
+      val projectTransformer = resultDF.queryExecution.executedPlan.collect {
+        case p: GpuProjectExec => p
+      }
+      if (projectTransformer.size == 1) {
+        TestStats.offloadRapidsUnitNumber += 1
+        logInfo("Offload to native backend in the test.\n")
+      } else {
+        logInfo("Not supported in native backend, fall back to vanilla spark in the test.\n")
+        shouldNotFallback()
+      }
+    } else {
+      logInfo("Has unsupported data type, fall back to vanilla spark.\n")
+      shouldNotFallback()
+    }
+
+    if (
+      !(checkResult(result.head.get(0), expected, expression.dataType, expression.nullable)
+        || checkResult(
+          CatalystTypeConverters.createToCatalystConverter(expression.dataType)(
+            result.head.get(0)
+          ), // the decimal precision inferred from the raw value can differ
+          CatalystTypeConverters.convertToCatalyst(expected),
+          expression.dataType,
+          expression.nullable
+        ))
+    ) {
+      val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+      fail(
+        s"Incorrect evaluation: $expression, " +
+          s"actual: ${result.head.get(0)}, " +
+          s"expected: $expected$input")
+    }
+  }
+
+  def shouldNotFallback(): Unit = {
+    TestStats.offloadRapids = false
+  }
+
+  def canConvertToDataFrame(inputRow: InternalRow): Boolean = {
+    if (inputRow == EmptyRow || inputRow == InternalRow.empty) {
+      return true
+    }
+    if (!inputRow.isInstanceOf[GenericInternalRow]) {
+      return false
+    }
+    val values = inputRow.asInstanceOf[GenericInternalRow].values
+    for (value <- values) {
+      value match {
+        case _: MapData => return false
+        case _: ArrayData => return false
+        case _: InternalRow => return false
+        case _ =>
+      }
+    }
+    true
+  }
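
The offload check in rapidsCheckExpression reduces to plan inspection: run the projection, then look for GpuProjectExec in the executed plan. A condensed, standalone form of that check (the helper name is ours; assumes a session configured with the SQLPlugin as above):

    import com.nvidia.spark.rapids.GpuProjectExec
    import org.apache.spark.sql.DataFrame

    // Exactly one GpuProjectExec node means the projection ran on the GPU path;
    // zero means the plugin fell back to the CPU.
    def projectionOffloaded(df: DataFrame): Boolean = {
      df.collect() // force execution so the final physical plan is in place
      df.queryExecution.executedPlan.collect { case p: GpuProjectExec => p }.size == 1
    }
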
+  def convertInternalRowToDataFrame(inputRow: InternalRow): DataFrame = {
+    val structFileSeq = new ArrayBuffer[StructField]()
+    val values = inputRow match {
+      case genericInternalRow: GenericInternalRow =>
+        genericInternalRow.values
+      case _ => throw new UnsupportedOperationException("Unsupported InternalRow.")
+    }
+    values.foreach {
+      case boolean: java.lang.Boolean =>
+        structFileSeq.append(StructField("bool", BooleanType, boolean == null))
+      case short: java.lang.Short =>
+        structFileSeq.append(StructField("i16", ShortType, short == null))
+      case byte: java.lang.Byte =>
+        structFileSeq.append(StructField("i8", ByteType, byte == null))
+      case integer: java.lang.Integer =>
+        structFileSeq.append(StructField("i32", IntegerType, integer == null))
+      case long: java.lang.Long =>
+        structFileSeq.append(StructField("i64", LongType, long == null))
+      case float: java.lang.Float =>
+        structFileSeq.append(StructField("fp32", FloatType, float == null))
+      case double: java.lang.Double =>
+        structFileSeq.append(StructField("fp64", DoubleType, double == null))
+      case utf8String: UTF8String =>
+        structFileSeq.append(StructField("str", StringType, utf8String == null))
+      case byteArr: Array[Byte] =>
+        structFileSeq.append(StructField("vbin", BinaryType, byteArr == null))
+      case decimal: Decimal =>
+        structFileSeq.append(
+          StructField("dec", DecimalType(decimal.precision, decimal.scale), decimal == null))
+      case null =>
+        structFileSeq.append(StructField("null", IntegerType, nullable = true))
+      case unsupported =>
+        throw new UnsupportedOperationException(s"Unsupported type: ${unsupported.getClass}")
+    }
+    val fields = structFileSeq.toSeq
+    _spark.internalCreateDataFrame(
+      _spark.sparkContext.parallelize(Seq(inputRow)),
+      StructType(fields))
+  }
+}
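
Because convertInternalRowToDataFrame assigns a fixed column name per boxed JVM type, a given row maps to a deterministic schema. A sketch of the round trip; this assumes it is invoked from a class mixing in RapidsTestsTrait, so _spark has been initialized:

    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.unsafe.types.UTF8String

    // An (Int, String) row converts to a two-column frame named with the
    // fixed identifiers above: i32 and str.
    val row = new GenericInternalRow(
      Array[Any](java.lang.Integer.valueOf(42), UTF8String.fromString("rapids")))
    val df = convertInternalRowToDataFrame(row)
    // df.schema =>
    //   StructType(StructField(i32,IntegerType,false), StructField(str,StringType,false))
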