diff --git a/integration_tests/src/main/python/misc_expr_test.py b/integration_tests/src/main/python/misc_expr_test.py index 0895d451b9d..04c39261349 100644 --- a/integration_tests/src/main/python/misc_expr_test.py +++ b/integration_tests/src/main/python/misc_expr_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2024, NVIDIA CORPORATION. +# Copyright (c) 2020-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error from data_gen import * from marks import incompat, approximate_float from pyspark.sql.types import * @@ -33,27 +33,75 @@ def test_part_id(): f.col('a'), f.spark_partition_id())) +# Spark conf key for choosing legacy error semantics. +legacy_semantics_key = "spark.sql.legacy.raiseErrorWithoutErrorClass" + +def raise_error_test_impl(test_conf): + use_new_error_semantics = legacy_semantics_key in test_conf and test_conf[legacy_semantics_key] == False -@pytest.mark.skipif(condition=is_spark_400_or_later() or is_databricks_version_or_later(14, 3), - reason="raise_error() not currently implemented for Spark 4.0, or Databricks 14.3. " - "See https://github.com/NVIDIA/spark-rapids/issues/10107.") -def test_raise_error(): data_gen = ShortGen(nullable=False, min_val=0, max_val=20, special_cases=[]) + + # Test for "when" selecting the "raise_error()" expression (null-type). assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen, num_slices=2).select( - f.when(f.col('a') > 30, f.raise_error("unexpected")))) + f.when(f.col('a') > 30, f.raise_error("unexpected"))), + conf=test_conf) + + # Test for if/else, with raise_error in the else. + # This should test if the data-type of raise_error() interferes with + # the result-type of the parent expression (if/else). + assert_gpu_and_cpu_are_equal_sql( + lambda spark: unary_op_df(spark, data_gen, num_slices=2), + 'test_table', + """ + SELECT IF( a < 30, a, raise_error('unexpected') ) + FROM test_table + """, + conf=test_conf) assert_gpu_and_cpu_are_equal_collect( - lambda spark: spark.range(0).select(f.raise_error(f.col("id")))) + lambda spark: spark.range(0).select(f.raise_error(f.col("id"))), + conf=test_conf) + error_fragment = "org.apache.spark.SparkRuntimeException" if use_new_error_semantics \ + else "java.lang.RuntimeException" assert_gpu_and_cpu_error( lambda spark: unary_op_df(spark, null_gen, length=2, num_slices=1).select( f.raise_error(f.col('a'))).collect(), - conf={}, - error_message="java.lang.RuntimeException") + conf=test_conf, + error_message=error_fragment) + error_fragment = error_fragment + (": [USER_RAISED_EXCEPTION] unexpected" if use_new_error_semantics + else ": unexpected") assert_gpu_and_cpu_error( lambda spark: unary_op_df(spark, short_gen, length=2, num_slices=1).select( f.raise_error(f.lit("unexpected"))).collect(), - conf={}, - error_message="java.lang.RuntimeException: unexpected") + conf=test_conf, + error_message=error_fragment) + + +def test_raise_error_legacy_semantics(): + """ + Tests the "legacy" semantics of raise_error(), i.e. where the error + does not include an error class. + """ + if is_spark_400_or_later() or is_databricks_version_or_later(14, 3): + # Spark 4+ and Databricks 14.3+ support RaiseError with error-classes included. + # Must test "legacy" mode, where error-classes are excluded. + raise_error_test_impl(test_conf={legacy_semantics_key: True}) + else: + # Spark versions preceding 4.0, or Databricks 14.3 do not support RaiseError with + # error-classes. No legacy mode need be selected. + raise_error_test_impl(test_conf={}) + + +@pytest.mark.skipif(condition=not (is_spark_400_or_later() or is_databricks_version_or_later(14, 3)), + reason="RaiseError semantics with error-classes are only supported " + "on Spark 4.0+ and Databricks 14.3+.") +def test_raise_error_new_semantics(): + """ + Tests the "new" semantics of raise_error(), i.e. where the error + includes an error class. Unsupported in Spark versions predating + Spark 4.0, Databricks 14.3. + """ + raise_error_test_impl(test_conf={legacy_semantics_key: False}) \ No newline at end of file diff --git a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala index 2d1f011b9d9..3ad318c3441 100644 --- a/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala +++ b/sql-plugin/src/main/spark350db143/scala/com/nvidia/spark/rapids/shims/RaiseErrorShim.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,10 +19,24 @@ spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids.ExprRule +import com.nvidia.spark.rapids.{ExprRule, GpuOverrides} +import com.nvidia.spark.rapids.{ExprChecks, TypeEnum, TypeSig} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, RaiseError} +import org.apache.spark.sql.rapids.shims.RaiseErrorMeta object RaiseErrorShim { - val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Map.empty + val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + Seq(GpuOverrides.expr[RaiseError]( + "Throw an exception", + ExprChecks.binaryProject( + TypeSig.NULL, TypeSig.NULL, + // In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression + // (i.e. the error-class) as a scalar value. A vector/column here would be surprising. + ("errorClass", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + ("errorParams", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.STRING)), + ), + (a, conf, p, r) => new RaiseErrorMeta(a, conf, p, r) + )).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap + } } diff --git a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala new file mode 100644 index 00000000000..ff580829ee8 --- /dev/null +++ b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/shims/misc.scala @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/*** spark-rapids-shim-json-lines +{"spark": "350db143"} +{"spark": "400"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.shims + +import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar} +import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuMapUtils, GpuScalar, RapidsConf, RapidsMeta} +import com.nvidia.spark.rapids.Arm.withResource + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, RaiseError} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData} +import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError +import org.apache.spark.sql.types.{AbstractDataType, DataType, NullType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * Implements `raise_error()` for Databricks 14.3 and Spark 4.0. + * Note that while the arity `raise_error()` remains 1 for all user-facing APIs + * (SQL, Scala, Python). But internally, the implementation uses a binary expression, + * where the first argument indicates the "error-class" for the error being raised. + */ +case class GpuRaiseError(left: Expression, right: Expression, dataType: DataType) + extends GpuBinaryExpression with ExpectsInputTypes { + + val errorClass: Expression = left + val errorParams: Expression = right + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def toString: String = s"raise_error($errorClass, $errorParams)" + + /** Could evaluating this expression cause side-effects, such as throwing an exception? */ + override def hasSideEffects: Boolean = true + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = + throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal") + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = + throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal") + + private def extractScalaUTF8String(stringScalar: Scalar): UTF8String = { + // This is guaranteed to be a string scalar. + GpuScalar.extract(stringScalar).asInstanceOf[UTF8String] + } + + private def extractStrings(stringsColumn: ColumnView): Array[UTF8String] = { + val size = stringsColumn.getRowCount.asInstanceOf[Int] // Already checked if exceeds threshold. + val output: Array[UTF8String] = new Array[UTF8String](size) + for (i <- 0 until size) { + output(i) = withResource(stringsColumn.getScalarElement(i)) { + extractScalaUTF8String(_) + } + } + output + } + + private def makeMapData(listOfStructs: ColumnView): MapData = { + val THRESHOLD: Int = 10 // Avoiding surprises with large maps. + // All testing indicates a map with 1 entry. + val mapSize = listOfStructs.getRowCount + + if (mapSize > THRESHOLD) { + throw new UnsupportedOperationException("Unexpectedly large error-parameter map") + } + + val outputKeys: Array[UTF8String] = + withResource(GpuMapUtils.getKeysAsListView(listOfStructs)) { listOfKeys => + withResource(listOfKeys.getChildColumnView(0)) { // Strings child of LIST column. + extractStrings(_) + } + } + + val outputVals: Array[UTF8String] = + withResource(GpuMapUtils.getValuesAsListView(listOfStructs)) { listOfVals => + withResource(listOfVals.getChildColumnView(0)) { // Strings child of LIST column. + extractStrings(_) + } + } + + ArrayBasedMapData(outputKeys, outputVals) + } + + override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { + if (rhs.getRowCount <= 0) { + // For the case: when(condition, raise_error(col("a")) + // When `condition` selects no rows, a vector of nulls should be returned, + // instead of throwing. + return GpuColumnVector.columnVectorFromNull(0, NullType) + } + + val lhsErrorClass = lhs.getValue.asInstanceOf[UTF8String] + + val rhsMapData = withResource(rhs.getBase.slice(0,1)) { slices => + val firstRhsRow = slices(0) + makeMapData(firstRhsRow) + } + + throw raiseError(lhsErrorClass, rhsMapData) + } + + override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { + if (numRows <= 0) { + // For the case: when(condition, raise_error(col("a")) + // When `condition` selects no rows, a vector of nulls should be returned, + // instead of throwing. + return GpuColumnVector.columnVectorFromNull(0, NullType) + } + + val errorClass = lhs.getValue.asInstanceOf[UTF8String] + // TODO (future): Check if the map-data needs to be extracted differently. + // All testing indicates that the host value of the map literal is set always pre-set. + // But if it isn't, then GpuScalar.getValue might extract it incorrectly. + // https://github.com/NVIDIA/spark-rapids/issues/11974 + val errorParams = rhs.getValue.asInstanceOf[MapData] + throw raiseError(errorClass, errorParams) + } +} + +class RaiseErrorMeta(r: RaiseError, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule ) + extends BinaryExprMeta[RaiseError](r, conf, parent, rule) { + override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression): GpuExpression + = GpuRaiseError(lhsErrorClass, rhsErrorParams, r.dataType) +} \ No newline at end of file diff --git a/tools/generated_files/400/operatorsScore.csv b/tools/generated_files/400/operatorsScore.csv index 0a099fc2233..a3b77fffb82 100644 --- a/tools/generated_files/400/operatorsScore.csv +++ b/tools/generated_files/400/operatorsScore.csv @@ -218,6 +218,7 @@ PythonUDAF,4 PythonUDF,4 Quarter,4 RLike,4 +RaiseError,4 Rand,4 Rank,4 RegExpExtract,4 diff --git a/tools/generated_files/400/supportedExprs.csv b/tools/generated_files/400/supportedExprs.csv index 926bb4f6c36..2e8c37f7f45 100644 --- a/tools/generated_files/400/supportedExprs.csv +++ b/tools/generated_files/400/supportedExprs.csv @@ -461,6 +461,9 @@ Quarter,S,`quarter`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +RaiseError,S, ,None,project,errorClass,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +RaiseError,S, ,None,project,errorParams,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA +RaiseError,S, ,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA Rand,S,`rand`; `random`,None,project,seed,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Rand,S,`rand`; `random`,None,project,result,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Rank,S,`rank`,None,window,ordering,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS,NS,NS