Skip to content

Commit

Permalink
Support raise_error() on Databricks 14.3, Spark 4.
Browse files Browse the repository at this point in the history
Fixes NVIDIA#10969.

This commit adds support for `raise_error()` on Databricks 14.3 and
Spark 4.0.

On these new Spark versions, the `RaiseError` expression (that powers
the `raise_error()` API function) was changed from a Unary expression to
a Binary one.  This was done without modifying the arity of
`raise_error()`.  The ostensible reason seems to have been to eventually allow
user-code to raise custom errors via `raise_error()`.

This commit allows `raise_error()` to work on the GPU as it currently
does on the CPU:  as a unary function powered by a binary expression in
the background.

The tests have been modified to verify both the new behaviour and the
legacy one on new platforms, while continuing to run as before on legacy
platforms.

Signed-off-by: MithunR <mithunr@nvidia.com>
  • Loading branch information
mythrocks committed Jan 16, 2025
1 parent 50b14de commit e45d5e9
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 15 deletions.
53 changes: 42 additions & 11 deletions integration_tests/src/main/python/misc_expr_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,27 +33,58 @@ def test_part_id():
f.col('a'),
f.spark_partition_id()))

# Spark conf key for choosing legacy error semantics.
legacy_semantics_key = "spark.sql.legacy.raiseErrorWithoutErrorClass"
is_new_raise_error_semantics_version=is_spark_400_or_later() or is_databricks_version_or_later(14, 3)

def raise_error_test_impl(test_conf):
use_new_error_semantics = legacy_semantics_key in test_conf and test_conf[legacy_semantics_key] == False

@pytest.mark.skipif(condition=is_spark_400_or_later() or is_databricks_version_or_later(14, 3),
reason="raise_error() not currently implemented for Spark 4.0, or Databricks 14.3. "
"See https://github.com/NVIDIA/spark-rapids/issues/10107.")
def test_raise_error():
data_gen = ShortGen(nullable=False, min_val=0, max_val=20, special_cases=[])
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen, num_slices=2).select(
f.when(f.col('a') > 30, f.raise_error("unexpected"))))
f.when(f.col('a') > 30, f.raise_error("unexpected"))),
conf=test_conf)

assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.range(0).select(f.raise_error(f.col("id"))))
lambda spark: spark.range(0).select(f.raise_error(f.col("id"))),
conf=test_conf)

error_fragment = "org.apache.spark.SparkRuntimeException" if use_new_error_semantics \
else "java.lang.RuntimeException"
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, null_gen, length=2, num_slices=1).select(
f.raise_error(f.col('a'))).collect(),
conf={},
error_message="java.lang.RuntimeException")
conf=test_conf,
error_message=error_fragment)

error_fragment = error_fragment + (": [USER_RAISED_EXCEPTION] unexpected" if use_new_error_semantics
else ": unexpected")
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, short_gen, length=2, num_slices=1).select(
f.raise_error(f.lit("unexpected"))).collect(),
conf={},
error_message="java.lang.RuntimeException: unexpected")
conf=test_conf,
error_message=error_fragment)


def test_raise_error_legacy_semantics():
"""
Tests the "legacy" semantics of raise_error(), i.e. where the error
does not include an error class.
"""
if is_new_raise_error_semantics_version:
raise_error_test_impl(test_conf={legacy_semantics_key: True})
else:
raise_error_test_impl(test_conf={})


@pytest.mark.skipif(condition=not is_new_raise_error_semantics_version,
reason="New raise_error semantics (with error-class) is only available "
"on Spark 4.0 and Databricks 14.3.")
def test_raise_error_new_semantics():
"""
Tests the "new" semantics of raise_error(), i.e. where the error
includes an error class. Unsupported in Spark versions predating
Spark 4.0, Databricks 14.3.
"""
raise_error_test_impl(test_conf={legacy_semantics_key: False})
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,10 +19,34 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.ExprRule
import com.nvidia.spark.rapids.{ExprRule, GpuOverrides}
import com.nvidia.spark.rapids.{ExprChecks, GpuExpression, TypeSig, BinaryExprMeta}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RaiseError}
import org.apache.spark.sql.rapids.shims.GpuRaiseError

object RaiseErrorShim {
val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Map.empty
val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
Seq(GpuOverrides.expr[RaiseError](
"Throw an exception",
ExprChecks.binaryProject(
TypeSig.NULL, TypeSig.NULL,
("errorClass", TypeSig.STRING, TypeSig.STRING),
("errorParams", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.STRING)),
),
(a, conf, p, r) => new BinaryExprMeta[RaiseError](a, conf, p, r) {

override def tagExprForGpu(): Unit = {
// In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression (i.e. the error-class)
// as a scalar value. A vector/column here would be surprising.
a.errorClass match {
case _: Literal => // Supported.
case _ => willNotWorkOnGpu(s"expected error-class to be a STRING literal")
}
}

override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression): GpuExpression =
GpuRaiseError(lhsErrorClass, rhsErrorParams)
})).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*** spark-rapids-shim-json-lines
{"spark": "350db143"}
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuBinaryExpression, GpuMapUtils, GpuScalar}
import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.types.{AbstractDataType, DataType, NullType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
* Implements `raise_error()` for Databricks 14.3 and Spark 4.0.
* Note that while the arity `raise_error()` remains 1 for all user-facing APIs (SQL, Scala, Python).
* But internally, the implementation uses a binary expression, where the first argument indicates
* the "error-class" for the error being raised.
*/
case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryExpression with ExpectsInputTypes {

val errorClass: Expression = left
val errorParams: Expression = right

override def dataType: DataType = NullType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
override def toString: String = s"raise_error($errorClass, $errorParams)"

/** Could evaluating this expression cause side-effects, such as throwing an exception? */
override def hasSideEffects: Boolean = true

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector =
throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")

override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector =
throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")

private def extractScalaUTF8String(stringScalar: Scalar): UTF8String = {
// This is guaranteed to be a string scalar.
GpuScalar.extract(stringScalar).asInstanceOf[UTF8String]
}

private def extractStrings(stringsColumn: ColumnView): Array[UTF8String] = {
val size = stringsColumn.getRowCount.asInstanceOf[Int] // Already checked if exceeds threshold.
val output: Array[UTF8String] = new Array[UTF8String](size)
for (i <- 0 until size) {
output(i) = withResource(stringsColumn.getScalarElement(i)) {
extractScalaUTF8String(_)
}
}
output
}

private def makeMapData(listOfStructs: ColumnView): MapData = {
val THRESHOLD: Int = 10 // Avoiding surprises with large maps. All testing indicates a map with 1 entry.
val mapSize = listOfStructs.getRowCount

if (mapSize > THRESHOLD)
throw new UnsupportedOperationException("Unexpectedly large error-parameter map")

val outputKeys: Array[UTF8String] =
withResource(GpuMapUtils.getKeysAsListView(listOfStructs)) { listOfKeys =>
withResource(listOfKeys.getChildColumnView(0)) { // Strings child of LIST column.
extractStrings(_)
}
}

val outputVals: Array[UTF8String] =
withResource(GpuMapUtils.getValuesAsListView(listOfStructs)) { listOfVals =>
withResource(listOfVals.getChildColumnView(0)) { // Strings child of LIST column.
extractStrings(_)
}
}

ArrayBasedMapData(outputKeys, outputVals)
}

override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = {
if (rhs.getRowCount <= 0) {
// For the case: when(condition, raise_error(col("a"))
// When `condition` selects no rows, a vector of nulls should be returned, instead of throwing.
return GpuColumnVector.columnVectorFromNull(0, NullType)
}

val lhsErrorClass = lhs.getValue.asInstanceOf[UTF8String]

val rhsMapData = withResource(rhs.getBase.slice(0,1)) { slices =>
val firstRhsRow = slices(0)
makeMapData(firstRhsRow)
}

throw raiseError(lhsErrorClass, rhsMapData)
}

override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
if (numRows <= 0) {
// For the case: when(condition, raise_error(col("a"))
// When `condition` selects no rows, a vector of nulls should be returned, instead of throwing.
return GpuColumnVector.columnVectorFromNull(0, NullType)
}

val errorClass = lhs.getValue.asInstanceOf[UTF8String]
val errorParams = rhs.getValue.asInstanceOf[MapData]
throw raiseError(errorClass, errorParams)
}
}

0 comments on commit e45d5e9

Please sign in to comment.