From e45d5e9bc33137052c103ec2a1252afd3568364e Mon Sep 17 00:00:00 2001
From: MithunR
Date: Thu, 16 Jan 2025 02:20:47 +0000
Subject: [PATCH 01/10] Support `raise_error()` on Databricks 14.3, Spark 4.

Fixes #10969.

This commit adds support for `raise_error()` on Databricks 14.3 and
Spark 4.0.

On these new Spark versions, the `RaiseError` expression (which powers the
`raise_error()` API function) was changed from a Unary expression to a
Binary one, without modifying the arity of `raise_error()`. The ostensible
reason was to eventually allow user code to raise custom errors via
`raise_error()`.

This commit allows `raise_error()` to work on the GPU as it currently does
on the CPU: as a unary function powered by a binary expression in the
background.

The tests have been modified to verify both the new behaviour and the
legacy one on new platforms, while continuing to run as before on legacy
platforms.

Signed-off-by: MithunR
---
 .../src/main/python/misc_expr_test.py         |  53 ++++++--
 .../spark/rapids/shims/RaiseErrorShim.scala   |  32 ++++-
 .../apache/spark/sql/rapids/shims/misc.scala  | 124 ++++++++++++++++++
 3 files changed, 194 insertions(+), 15 deletions(-)
 create mode 100644 sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala

diff --git a/integration_tests/src/main/python/misc_expr_test.py b/integration_tests/src/main/python/misc_expr_test.py
index 0895d451b9d..32277234847 100644
--- a/integration_tests/src/main/python/misc_expr_test.py
+++ b/integration_tests/src/main/python/misc_expr_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2024, NVIDIA CORPORATION.
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,27 +33,58 @@ def test_part_id():
                 f.col('a'),
                 f.spark_partition_id()))
 
+# Spark conf key for choosing legacy error semantics.
+legacy_semantics_key = "spark.sql.legacy.raiseErrorWithoutErrorClass"
+is_new_raise_error_semantics_version = is_spark_400_or_later() or is_databricks_version_or_later(14, 3)
+
+def raise_error_test_impl(test_conf):
+    use_new_error_semantics = legacy_semantics_key in test_conf and test_conf[legacy_semantics_key] == False
 
-@pytest.mark.skipif(condition=is_spark_400_or_later() or is_databricks_version_or_later(14, 3),
-                    reason="raise_error() not currently implemented for Spark 4.0, or Databricks 14.3. "
-                           "See https://github.com/NVIDIA/spark-rapids/issues/10107.")
-def test_raise_error():
     data_gen = ShortGen(nullable=False, min_val=0, max_val=20, special_cases=[])
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: unary_op_df(spark, data_gen, num_slices=2).select(
-            f.when(f.col('a') > 30, f.raise_error("unexpected"))))
+            f.when(f.col('a') > 30, f.raise_error("unexpected"))),
+        conf=test_conf)
 
     assert_gpu_and_cpu_are_equal_collect(
-        lambda spark: spark.range(0).select(f.raise_error(f.col("id"))))
+        lambda spark: spark.range(0).select(f.raise_error(f.col("id"))),
+        conf=test_conf)
 
+    error_fragment = "org.apache.spark.SparkRuntimeException" if use_new_error_semantics \
+        else "java.lang.RuntimeException"
     assert_gpu_and_cpu_error(
         lambda spark: unary_op_df(spark, null_gen, length=2, num_slices=1).select(
             f.raise_error(f.col('a'))).collect(),
-        conf={},
-        error_message="java.lang.RuntimeException")
+        conf=test_conf,
+        error_message=error_fragment)
 
+    error_fragment = error_fragment + (": [USER_RAISED_EXCEPTION] unexpected" if use_new_error_semantics
+                                       else ": unexpected")
     assert_gpu_and_cpu_error(
         lambda spark: unary_op_df(spark, short_gen, length=2, num_slices=1).select(
             f.raise_error(f.lit("unexpected"))).collect(),
-        conf={},
-        error_message="java.lang.RuntimeException: unexpected")
+        conf=test_conf,
+        error_message=error_fragment)
+
+
+def test_raise_error_legacy_semantics():
+    """
+    Tests the "legacy" semantics of raise_error(), i.e. where the error
+    does not include an error class.
+    """
+    if is_new_raise_error_semantics_version:
+        raise_error_test_impl(test_conf={legacy_semantics_key: True})
+    else:
+        raise_error_test_impl(test_conf={})
+
+
+@pytest.mark.skipif(condition=not is_new_raise_error_semantics_version,
+                    reason="New raise_error semantics (with error-class) is only available "
+                           "on Spark 4.0 and Databricks 14.3.")
+def test_raise_error_new_semantics():
+    """
+    Tests the "new" semantics of raise_error(), i.e. where the error
+    includes an error class. Unsupported in Spark versions predating
+    Spark 4.0, Databricks 14.3.
+    """
+    raise_error_test_impl(test_conf={legacy_semantics_key: False})
\ No newline at end of file
diff --git a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
index 2d1f011b9d9..5e5a22209ad 100644
--- a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
+++ b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,10 +19,34 @@ spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids.shims
 
-import com.nvidia.spark.rapids.ExprRule
+import com.nvidia.spark.rapids.{ExprRule, GpuOverrides}
+import com.nvidia.spark.rapids.{ExprChecks, GpuExpression, TypeSig, BinaryExprMeta}
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RaiseError}
+import org.apache.spark.sql.rapids.shims.GpuRaiseError
 
 object RaiseErrorShim {
-  val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Map.empty
+  val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
+    Seq(GpuOverrides.expr[RaiseError](
+      "Throw an exception",
+      ExprChecks.binaryProject(
+        TypeSig.NULL, TypeSig.NULL,
+        ("errorClass", TypeSig.STRING, TypeSig.STRING),
+        ("errorParams", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.STRING)),
+      ),
+      (a, conf, p, r) => new BinaryExprMeta[RaiseError](a, conf, p, r) {
+
+        override def tagExprForGpu(): Unit = {
+          // In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression (i.e. the error-class)
+          // as a scalar value. A vector/column here would be surprising.
+          a.errorClass match {
+            case _: Literal => // Supported.
+            case _ => willNotWorkOnGpu(s"expected error-class to be a STRING literal")
+          }
+        }
+
+        override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression): GpuExpression =
+          GpuRaiseError(lhsErrorClass, rhsErrorParams)
+      })).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
+  }
 }
diff --git a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
new file mode 100644
index 00000000000..d45c9d32eeb
--- /dev/null
+++ b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*** spark-rapids-shim-json-lines
+{"spark": "350db143"}
+{"spark": "400"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.shims
+
+import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar}
+import com.nvidia.spark.rapids.{GpuColumnVector, GpuBinaryExpression, GpuMapUtils, GpuScalar}
+import com.nvidia.spark.rapids.Arm.withResource
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData}
+import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
+import org.apache.spark.sql.types.{AbstractDataType, DataType, NullType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Implements `raise_error()` for Databricks 14.3 and Spark 4.0.
+ * Note that the arity of `raise_error()` remains 1 for all user-facing APIs (SQL, Scala, Python).
+ * Internally, however, the implementation uses a binary expression, where the first argument
+ * indicates the "error-class" for the error being raised.
+ */
+case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryExpression with ExpectsInputTypes {
+
+  val errorClass: Expression = left
+  val errorParams: Expression = right
+
+  override def dataType: DataType = NullType
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+  override def toString: String = s"raise_error($errorClass, $errorParams)"
+
+  /** Could evaluating this expression cause side-effects, such as throwing an exception? */
+  override def hasSideEffects: Boolean = true
+
+  override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector =
+    throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")
+
+  override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector =
+    throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")
+
+  private def extractScalaUTF8String(stringScalar: Scalar): UTF8String = {
+    // This is guaranteed to be a string scalar.
+    GpuScalar.extract(stringScalar).asInstanceOf[UTF8String]
+  }
+
+  private def extractStrings(stringsColumn: ColumnView): Array[UTF8String] = {
+    val size = stringsColumn.getRowCount.asInstanceOf[Int] // Already checked if exceeds threshold.
+    val output: Array[UTF8String] = new Array[UTF8String](size)
+    for (i <- 0 until size) {
+      output(i) = withResource(stringsColumn.getScalarElement(i)) {
+        extractScalaUTF8String(_)
+      }
+    }
+    output
+  }
+
+  private def makeMapData(listOfStructs: ColumnView): MapData = {
+    val THRESHOLD: Int = 10 // Avoiding surprises with large maps. All testing indicates a map with 1 entry.
+    val mapSize = listOfStructs.getRowCount
+
+    if (mapSize > THRESHOLD)
+      throw new UnsupportedOperationException("Unexpectedly large error-parameter map")
+
+    val outputKeys: Array[UTF8String] =
+      withResource(GpuMapUtils.getKeysAsListView(listOfStructs)) { listOfKeys =>
+        withResource(listOfKeys.getChildColumnView(0)) { // Strings child of LIST column.
+          extractStrings(_)
+        }
+      }
+
+    val outputVals: Array[UTF8String] =
+      withResource(GpuMapUtils.getValuesAsListView(listOfStructs)) { listOfVals =>
+        withResource(listOfVals.getChildColumnView(0)) { // Strings child of LIST column.
+          extractStrings(_)
+        }
+      }
+
+    ArrayBasedMapData(outputKeys, outputVals)
+  }
+
+  override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = {
+    if (rhs.getRowCount <= 0) {
+      // For the case: when(condition, raise_error(col("a")))
+      // When `condition` selects no rows, a vector of nulls should be returned, instead of throwing.
+      return GpuColumnVector.columnVectorFromNull(0, NullType)
+    }
+
+    val lhsErrorClass = lhs.getValue.asInstanceOf[UTF8String]
+
+    val rhsMapData = withResource(rhs.getBase.slice(0, 1)) { slices =>
+      val firstRhsRow = slices(0)
+      makeMapData(firstRhsRow)
+    }
+
+    throw raiseError(lhsErrorClass, rhsMapData)
+  }
+
+  override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
+    if (numRows <= 0) {
+      // For the case: when(condition, raise_error(col("a")))
+      // When `condition` selects no rows, a vector of nulls should be returned, instead of throwing.
+      return GpuColumnVector.columnVectorFromNull(0, NullType)
+    }
+
+    val errorClass = lhs.getValue.asInstanceOf[UTF8String]
+    val errorParams = rhs.getValue.asInstanceOf[MapData]
+    throw raiseError(errorClass, errorParams)
+  }
+}

From 24008a7d8cd26c280fd4afe5b8d330487fa091d7 Mon Sep 17 00:00:00 2001
From: MithunR
Date: Thu, 16 Jan 2025 08:11:06 -0800
Subject: [PATCH 02/10] Style fixes.

---
 .../spark/rapids/shims/RaiseErrorShim.scala   |  8 ++++----
 .../apache/spark/sql/rapids/shims/misc.scala  | 18 +++++++++++-------
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
index 5e5a22209ad..b9ac1f1c155 100644
--- a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
+++ b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
@@ -37,16 +37,16 @@ object RaiseErrorShim {
       (a, conf, p, r) => new BinaryExprMeta[RaiseError](a, conf, p, r) {
 
         override def tagExprForGpu(): Unit = {
-          // In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression (i.e. the error-class)
-          // as a scalar value. A vector/column here would be surprising.
+          // In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression
+          // (i.e. the error-class) as a scalar value. A vector/column here would be surprising.
           a.errorClass match {
             case _: Literal => // Supported.
             case _ => willNotWorkOnGpu(s"expected error-class to be a STRING literal")
           }
         }
 
-        override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression): GpuExpression =
-          GpuRaiseError(lhsErrorClass, rhsErrorParams)
+        override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression)
+        : GpuExpression = GpuRaiseError(lhsErrorClass, rhsErrorParams)
       })).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
   }
 }
diff --git a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
index d45c9d32eeb..998ac47c9a1 100644
--- a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
+++ b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
@@ -31,11 +31,12 @@ import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Implements `raise_error()` for Databricks 14.3 and Spark 4.0.
- * Note that the arity of `raise_error()` remains 1 for all user-facing APIs (SQL, Scala, Python).
- * Internally, however, the implementation uses a binary expression, where the first argument
- * indicates the "error-class" for the error being raised.
+ * Note that the arity of `raise_error()` remains 1 for all user-facing APIs
+ * (SQL, Scala, Python). Internally, however, the implementation uses a binary expression,
+ * where the first argument indicates the "error-class" for the error being raised.
  */
-case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryExpression with ExpectsInputTypes {
+case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryExpression
+    with ExpectsInputTypes {
 
   val errorClass: Expression = left
   val errorParams: Expression = right
@@ -70,7 +71,8 @@ case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryE
   }
 
   private def makeMapData(listOfStructs: ColumnView): MapData = {
-    val THRESHOLD: Int = 10 // Avoiding surprises with large maps. All testing indicates a map with 1 entry.
+    val THRESHOLD: Int = 10 // Avoiding surprises with large maps.
+    // All testing indicates a map with 1 entry.
     val mapSize = listOfStructs.getRowCount
 
     if (mapSize > THRESHOLD)
@@ -96,7 +98,8 @@ case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryE
   override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = {
     if (rhs.getRowCount <= 0) {
       // For the case: when(condition, raise_error(col("a")))
-      // When `condition` selects no rows, a vector of nulls should be returned, instead of throwing.
+      // When `condition` selects no rows, a vector of nulls should be returned,
+      // instead of throwing.
       return GpuColumnVector.columnVectorFromNull(0, NullType)
     }
 
@@ -113,7 +116,8 @@ case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryE
   override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
     if (numRows <= 0) {
       // For the case: when(condition, raise_error(col("a")))
-      // When `condition` selects no rows, a vector of nulls should be returned, instead of throwing.
+      // When `condition` selects no rows, a vector of nulls should be returned,
+      // instead of throwing.
       return GpuColumnVector.columnVectorFromNull(0, NullType)
     }

From f6aba728210ee53644c1a132bfd61a63a979aa06 Mon Sep 17 00:00:00 2001
From: MithunR
Date: Thu, 16 Jan 2025 08:43:22 -0800
Subject: [PATCH 03/10] More style fixes.

---
 .../scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala | 2 +-
 .../scala/org/apache/spark/sql/rapids/shims/misc.scala       | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
index b9ac1f1c155..98563448b6d 100644
--- a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
+++ b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
@@ -20,7 +20,7 @@ spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids.shims
 
 import com.nvidia.spark.rapids.{ExprRule, GpuOverrides}
-import com.nvidia.spark.rapids.{ExprChecks, GpuExpression, TypeSig, BinaryExprMeta}
+import com.nvidia.spark.rapids.{BinaryExprMeta, ExprChecks, GpuExpression, TypeSig}
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RaiseError}
 import org.apache.spark.sql.rapids.shims.GpuRaiseError
diff --git a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
index 998ac47c9a1..70e14a4ae8d 100644
--- a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
+++ b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
@@ -20,7 +20,7 @@ spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.shims
 
 import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar}
-import com.nvidia.spark.rapids.{GpuColumnVector, GpuBinaryExpression, GpuMapUtils, GpuScalar}
+import com.nvidia.spark.rapids.{GpuBinaryExpression, GpuColumnVector, GpuMapUtils, GpuScalar}
 import com.nvidia.spark.rapids.Arm.withResource
 
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData}
@@ -75,8 +75,9 @@ case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryE
     // All testing indicates a map with 1 entry.
     val mapSize = listOfStructs.getRowCount
 
-    if (mapSize > THRESHOLD)
+    if (mapSize > THRESHOLD) {
       throw new UnsupportedOperationException("Unexpectedly large error-parameter map")
+    }
 
     val outputKeys: Array[UTF8String] =
       withResource(GpuMapUtils.getKeysAsListView(listOfStructs)) { listOfKeys =>

From 0f117012dcb4d9636be221d5399f72a30c2cef4d Mon Sep 17 00:00:00 2001
From: MithunR
Date: Thu, 16 Jan 2025 10:26:51 -0800
Subject: [PATCH 04/10] Supported ops/exprs update.

---
 tools/generated_files/400/operatorsScore.csv | 1 +
 tools/generated_files/400/supportedExprs.csv | 3 +++
 2 files changed, 4 insertions(+)

diff --git a/tools/generated_files/400/operatorsScore.csv b/tools/generated_files/400/operatorsScore.csv
index 0a099fc2233..a3b77fffb82 100644
--- a/tools/generated_files/400/operatorsScore.csv
+++ b/tools/generated_files/400/operatorsScore.csv
@@ -218,6 +218,7 @@ PythonUDAF,4
 PythonUDF,4
 Quarter,4
 RLike,4
+RaiseError,4
 Rand,4
 Rank,4
 RegExpExtract,4
diff --git a/tools/generated_files/400/supportedExprs.csv b/tools/generated_files/400/supportedExprs.csv
index 926bb4f6c36..9ece9101924 100644
--- a/tools/generated_files/400/supportedExprs.csv
+++ b/tools/generated_files/400/supportedExprs.csv
@@ -461,6 +461,9 @@ Quarter,S,`quarter`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+RaiseError,S, ,None,project,errorClass,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+RaiseError,S, ,None,project,errorParams,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA
+RaiseError,S, ,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
 Rand,S,`rand`; `random`,None,project,seed,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 Rand,S,`rand`; `random`,None,project,result,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 Rank,S,`rank`,None,window,ordering,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS,NS,NS

From c87b1142a569188e639e118aea7c022e02d591d7 Mon Sep 17 00:00:00 2001
From: MithunR
Date: Fri, 17 Jan 2025 04:50:46 +0000
Subject: [PATCH 05/10] Better use of ExprChecks.

---
 .../spark/rapids/shims/RaiseErrorShim.scala | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
index 98563448b6d..b09b42a9fc0 100644
--- a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
+++ b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
@@ -20,9 +20,9 @@ spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids.shims
 
 import com.nvidia.spark.rapids.{ExprRule, GpuOverrides}
-import com.nvidia.spark.rapids.{BinaryExprMeta, ExprChecks, GpuExpression, TypeSig}
+import com.nvidia.spark.rapids.{BinaryExprMeta, ExprChecks, GpuExpression, TypeEnum, TypeSig}
 
-import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RaiseError}
+import org.apache.spark.sql.catalyst.expressions.{Expression, RaiseError}
 import org.apache.spark.sql.rapids.shims.GpuRaiseError
 
 object RaiseErrorShim {
@@ -31,20 +31,12 @@ object RaiseErrorShim {
       "Throw an exception",
       ExprChecks.binaryProject(
         TypeSig.NULL, TypeSig.NULL,
-        ("errorClass", TypeSig.STRING, TypeSig.STRING),
+        // In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression
+        // (i.e. the error-class) as a scalar value. A vector/column here would be surprising.
+        ("errorClass", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
         ("errorParams", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.STRING)),
       ),
       (a, conf, p, r) => new BinaryExprMeta[RaiseError](a, conf, p, r) {
-
-        override def tagExprForGpu(): Unit = {
-          // In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression
-          // (i.e. the error-class) as a scalar value. A vector/column here would be surprising.
-          a.errorClass match {
-            case _: Literal => // Supported.
-            case _ => willNotWorkOnGpu(s"expected error-class to be a STRING literal")
-          }
-        }
-
         override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression)
         : GpuExpression = GpuRaiseError(lhsErrorClass, rhsErrorParams)
       })).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

From 03527433ae70437633502a98af46bfd50be97af1 Mon Sep 17 00:00:00 2001
From: MithunR
Date: Fri, 17 Jan 2025 05:57:22 +0000
Subject: [PATCH 06/10] Fixed the return data-type for RaiseError.

---
 .../src/main/python/misc_expr_test.py              | 16 +++++++++++++++-
 .../spark/rapids/shims/RaiseErrorShim.scala        |  2 +-
 .../org/apache/spark/sql/rapids/shims/misc.scala   |  5 ++---
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/integration_tests/src/main/python/misc_expr_test.py b/integration_tests/src/main/python/misc_expr_test.py
index 32277234847..d24dfe24fad 100644
--- a/integration_tests/src/main/python/misc_expr_test.py
+++ b/integration_tests/src/main/python/misc_expr_test.py
@@ -14,7 +14,7 @@
 
 import pytest
 
-from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error
 from data_gen import *
 from marks import incompat, approximate_float
 from pyspark.sql.types import *
@@ -41,11 +41,25 @@ def raise_error_test_impl(test_conf):
     use_new_error_semantics = legacy_semantics_key in test_conf and test_conf[legacy_semantics_key] == False
 
     data_gen = ShortGen(nullable=False, min_val=0, max_val=20, special_cases=[])
+
+    # Test for "when" selecting the "raise_error()" expression (null-type).
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: unary_op_df(spark, data_gen, num_slices=2).select(
             f.when(f.col('a') > 30, f.raise_error("unexpected"))),
         conf=test_conf)
 
+    # Test for if/else, with raise_error in the else.
+    # This should test if the data-type of raise_error() interferes with
+    # the result-type of the parent expression (if/else).
+    assert_gpu_and_cpu_are_equal_sql(
+        lambda spark: unary_op_df(spark, data_gen, num_slices=2),
+        'test_table',
+        """
+        SELECT IF( a < 30, a, raise_error('unexpected') )
+        FROM test_table
+        """,
+        conf=test_conf)
+
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: spark.range(0).select(f.raise_error(f.col("id"))),
         conf=test_conf)
 
diff --git a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
index b09b42a9fc0..f318a2f1500 100644
--- a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
+++ b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
@@ -38,7 +38,7 @@ object RaiseErrorShim {
       ),
       (a, conf, p, r) => new BinaryExprMeta[RaiseError](a, conf, p, r) {
         override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression)
-        : GpuExpression = GpuRaiseError(lhsErrorClass, rhsErrorParams)
+        : GpuExpression = GpuRaiseError(lhsErrorClass, rhsErrorParams, a.dataType)
       })).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
   }
 }
diff --git a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
index 70e14a4ae8d..ade1c910a8a 100644
--- a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
+++ b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
@@ -35,13 +35,12 @@ import org.apache.spark.unsafe.types.UTF8String
  * (SQL, Scala, Python). Internally, however, the implementation uses a binary expression,
  * where the first argument indicates the "error-class" for the error being raised.
  */
-case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryExpression
-    with ExpectsInputTypes {
+case class GpuRaiseError(left: Expression, right: Expression, dataType: DataType)
+    extends GpuBinaryExpression with ExpectsInputTypes {
 
   val errorClass: Expression = left
   val errorParams: Expression = right
 
-  override def dataType: DataType = NullType
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
   override def toString: String = s"raise_error($errorClass, $errorParams)"
 

From a1f142940859a92d389e5684e28c04dd31211820 Mon Sep 17 00:00:00 2001
From: MithunR
Date: Fri, 17 Jan 2025 06:20:44 +0000
Subject: [PATCH 07/10] Inlined Spark version checks.

---
 integration_tests/src/main/python/misc_expr_test.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/integration_tests/src/main/python/misc_expr_test.py b/integration_tests/src/main/python/misc_expr_test.py
index d24dfe24fad..04c39261349 100644
--- a/integration_tests/src/main/python/misc_expr_test.py
+++ b/integration_tests/src/main/python/misc_expr_test.py
@@ -35,7 +35,6 @@ def test_part_id():
 
 # Spark conf key for choosing legacy error semantics.
 legacy_semantics_key = "spark.sql.legacy.raiseErrorWithoutErrorClass"
-is_new_raise_error_semantics_version = is_spark_400_or_later() or is_databricks_version_or_later(14, 3)
 
 def raise_error_test_impl(test_conf):
     use_new_error_semantics = legacy_semantics_key in test_conf and test_conf[legacy_semantics_key] == False
@@ -86,15 +85,19 @@ def test_raise_error_legacy_semantics():
     """
     Tests the "legacy" semantics of raise_error(), i.e. where the error
     does not include an error class.
    """
-    if is_new_raise_error_semantics_version:
+    if is_spark_400_or_later() or is_databricks_version_or_later(14, 3):
+        # Spark 4+ and Databricks 14.3+ support RaiseError with error-classes included.
+        # Must test "legacy" mode, where error-classes are excluded.
         raise_error_test_impl(test_conf={legacy_semantics_key: True})
     else:
+        # Spark versions preceding 4.0 or Databricks 14.3 do not support RaiseError
+        # with error-classes. No legacy mode need be selected.
         raise_error_test_impl(test_conf={})
 
 
-@pytest.mark.skipif(condition=not is_new_raise_error_semantics_version,
-                    reason="New raise_error semantics (with error-class) is only available "
-                           "on Spark 4.0 and Databricks 14.3.")
+@pytest.mark.skipif(condition=not (is_spark_400_or_later() or is_databricks_version_or_later(14, 3)),
+                    reason="RaiseError semantics with error-classes are only supported "
+                           "on Spark 4.0+ and Databricks 14.3+.")
 def test_raise_error_new_semantics():
     """
     Tests the "new" semantics of raise_error(), i.e. where the error

From 2da72b9904978f0c1ee1ec156e5955238ad96f65 Mon Sep 17 00:00:00 2001
From: MithunR
Date: Fri, 17 Jan 2025 06:57:41 +0000
Subject: [PATCH 08/10] Moved RaiseError's meta to its own class.

---
 .../nvidia/spark/rapids/shims/RaiseErrorShim.scala | 10 ++++------
 .../org/apache/spark/sql/rapids/shims/misc.scala   | 13 +++++++++++--
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
index f318a2f1500..3ad318c3441 100644
--- a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
+++ b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala
@@ -20,10 +20,10 @@ spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids.shims
 
 import com.nvidia.spark.rapids.{ExprRule, GpuOverrides}
-import com.nvidia.spark.rapids.{BinaryExprMeta, ExprChecks, GpuExpression, TypeEnum, TypeSig}
+import com.nvidia.spark.rapids.{ExprChecks, TypeEnum, TypeSig}
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, RaiseError}
-import org.apache.spark.sql.rapids.shims.GpuRaiseError
+import org.apache.spark.sql.rapids.shims.RaiseErrorMeta
 
 object RaiseErrorShim {
   val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
@@ -36,9 +36,7 @@ object RaiseErrorShim {
         ("errorClass", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
         ("errorParams", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.STRING)),
       ),
-      (a, conf, p, r) => new BinaryExprMeta[RaiseError](a, conf, p, r) {
-        override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression)
-        : GpuExpression = GpuRaiseError(lhsErrorClass, rhsErrorParams, a.dataType)
-      })).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
+      (a, conf, p, r) => new RaiseErrorMeta(a, conf, p, r)
+    )).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
   }
 }
diff --git a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
index ade1c910a8a..d907bf47cd9 100644
--- a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
+++ b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
@@ -20,7 +20,7 @@ spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.shims
 
 import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar}
-import com.nvidia.spark.rapids.{GpuBinaryExpression, GpuColumnVector, GpuMapUtils, GpuScalar}
+import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuMapUtils, GpuScalar, RapidsConf, RapidsMeta}
 import com.nvidia.spark.rapids.Arm.withResource
 
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, RaiseError}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData}
@@ -126,3 +126,12 @@ case class GpuRaiseError(left: Expression, right: Expression, dataType: DataType
     throw raiseError(errorClass, errorParams)
   }
 }
+
+class RaiseErrorMeta(r: RaiseError,
+    conf: RapidsConf,
+    parent: Option[RapidsMeta[_, _, _]],
+    rule: DataFromReplacementRule)
+  extends BinaryExprMeta[RaiseError](r, conf, parent, rule) {
+  override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression): GpuExpression
+    = GpuRaiseError(lhsErrorClass, rhsErrorParams, r.dataType)
+}
\ No newline at end of file

From 5b9037451900a5410064cc6accbb4b474ddfea4c Mon Sep 17 00:00:00 2001
From: MithunR
Date: Fri, 17 Jan 2025 07:17:42 +0000
Subject: [PATCH 09/10] Documented the use of GpuScalar::getValue.

---
 .../scala/org/apache/spark/sql/rapids/shims/misc.scala | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
index d907bf47cd9..ff580829ee8 100644
--- a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
+++ b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala
@@ -122,6 +122,10 @@ case class GpuRaiseError(left: Expression, right: Expression, dataType: DataType
     }
 
     val errorClass = lhs.getValue.asInstanceOf[UTF8String]
+    // TODO (future): Check if the map-data needs to be extracted differently.
+    // All testing indicates that the host value of the map literal is always pre-set.
+    // But if it isn't, then GpuScalar.getValue might extract it incorrectly.
+    // https://github.com/NVIDIA/spark-rapids/issues/11974
     val errorParams = rhs.getValue.asInstanceOf[MapData]
     throw raiseError(errorClass, errorParams)
   }
 }

From fb101b3252d4fa0003916234dcf29fef18032278 Mon Sep 17 00:00:00 2001
From: MithunR
Date: Thu, 16 Jan 2025 23:30:15 -0800
Subject: [PATCH 10/10] Updated supportedExprs.csv

---
 tools/generated_files/400/supportedExprs.csv | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/generated_files/400/supportedExprs.csv b/tools/generated_files/400/supportedExprs.csv
index 9ece9101924..2e8c37f7f45 100644
--- a/tools/generated_files/400/supportedExprs.csv
+++ b/tools/generated_files/400/supportedExprs.csv
@@ -461,7 +461,7 @@ Quarter,S,`quarter`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
-RaiseError,S, ,None,project,errorClass,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+RaiseError,S, ,None,project,errorClass,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 RaiseError,S, ,None,project,errorParams,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA
 RaiseError,S, ,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
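
Note on the semantics under test: the difference between the legacy and the new
`raise_error()` behaviour can be reproduced on the CPU alone. The sketch below is
a minimal illustration and is not part of the patch series; it assumes a Spark 4.0
(or Databricks 14.3) session, and the session construction is illustrative only.
The conf key and the expected message fragments are the ones asserted in
misc_expr_test.py in PATCH 01.

    # Minimal sketch (assumes a Spark 4.0 / Databricks 14.3 session).
    import pyspark.sql.functions as f
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed session setup

    for legacy in (False, True):
        # legacy=False (the default on these versions) yields the new semantics:
        #   org.apache.spark.SparkRuntimeException: [USER_RAISED_EXCEPTION] unexpected
        # legacy=True yields the old semantics:
        #   java.lang.RuntimeException: unexpected
        spark.conf.set("spark.sql.legacy.raiseErrorWithoutErrorClass", legacy)
        try:
            spark.range(1).select(f.raise_error(f.lit("unexpected"))).collect()
        except Exception as e:
            print("legacy =", legacy, "->", str(e)[:100])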