[SPARK-44838][SQL] raise_error improvement
### What changes were proposed in this pull request?

Extend the raise_error() function to a two-argument version:
raise_error(errorClassStr, errorParamMap)
This new form accepts any error class defined in error-classes.json and requires a Map<String, String> that supplies values for the parameters in the error class's message template.
Externally, an error raised via raise_error() is indistinguishable from one raised from within the Spark engine.

The single-parameter raise_error(str) raises USER_RAISED_EXCEPTION (SQLSTATE P0001, borrowed from PostgreSQL).
The USER_RAISED_EXCEPTION message template is "<errorMessage>", which is filled in with the str value.

We also provide `spark.sql.legacy.raiseErrorWithoutErrorClass` (default: false) to revert the single-parameter version to the old behavior.

Naturally, assert_true() also raises `USER_RAISED_EXCEPTION`.

#### Examples
```
SELECT raise_error('VIEW_NOT_FOUND', map('relationName', '`v1`'));
  [VIEW_NOT_FOUND] The view `v1` cannot be found. Verify the spelling ...

SELECT raise_error('Error!');
  [USER_RAISED_EXCEPTION] Error!

SELECT assert_true(1 < 0);
  [USER_RAISED_EXCEPTION] '(1 < 0)' is not true!

SELECT assert_true(1 < 0, 'bad!');
  [USER_RAISED_EXCEPTION] bad!
```
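
A mismatched parameter map fails the call itself. A sketch of the expected behavior, with the message abridged from the USER_RAISED_EXCEPTION_PARAMETER_MISMATCH template introduced below (exact formatting may differ):
```
SELECT raise_error('VIEW_NOT_FOUND', map());
  [USER_RAISED_EXCEPTION_PARAMETER_MISMATCH] The `raise_error()` function was used to raise error class: 'VIEW_NOT_FOUND' which expects parameters: 'relationName'. ...
```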

### Why are the changes needed?

This change moves raise_error() and assert_true() to the new error framework.
It greatly expands users' ability to raise error messages that can be intercepted via SQLSTATE and/or error class.

### Does this PR introduce _any_ user-facing change?

Yes: the error message produced by assert_true() changes, and raise_error() gains a new signature.

### How was this patch tested?

Ran existing QA and added new tests for assert_true() and raise_error().

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #42985 from srielau/SPARK-44838-raise_error.

Lead-authored-by: srielau <serge@rielau.com>
Co-authored-by: Serge Rielau <srielau@users.noreply.github.com>
Co-authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
3 people authored and gengliangwang committed Sep 27, 2023
1 parent 8399dd3 commit 9109d70
Showing 23 changed files with 551 additions and 67 deletions.
26 changes: 26 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
@@ -3502,6 +3502,26 @@
"3. set \"spark.sql.legacy.allowUntypedScalaUDF\" to \"true\" and use this API with caution."
]
},
"USER_RAISED_EXCEPTION" : {
"message" : [
"<errorMessage>"
],
"sqlState" : "P0001"
},
"USER_RAISED_EXCEPTION_PARAMETER_MISMATCH" : {
"message" : [
"The `raise_error()` function was used to raise error class: <errorClass> which expects parameters: <expectedParms>.",
"The provided parameters <providedParms> do not match the expected parameters.",
"Please make sure to provide all expected parameters."
],
"sqlState" : "P0001"
},
"USER_RAISED_EXCEPTION_UNKNOWN_ERROR_CLASS" : {
"message" : [
"The `raise_error()` function was used to raise an unknown error class: <errorClass>"
],
"sqlState" : "P0001"
},
"VARIABLE_ALREADY_EXISTS" : {
"message" : [
"Cannot create the variable <variableName> because it already exists.",
@@ -6310,5 +6330,11 @@
"message" : [
"Failed to get block <blockId>, which is not a shuffle block"
]
},
"_LEGACY_ERROR_USER_RAISED_EXCEPTION" : {
"message" : [
"<errorMessage>"
],
"sqlState" : "P0001"
}
}
ErrorClassesJsonReader.scala
@@ -57,6 +57,13 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
}
}

def getMessageParameters(errorClass: String): Seq[String] = {
val messageTemplate = getMessageTemplate(errorClass)
val pattern = "<([a-zA-Z0-9_-]+)>".r
val matches = pattern.findAllIn(messageTemplate).toSeq
matches.map(m => m.stripSuffix(">").stripPrefix("<"))
}
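
The `<param>` tokens scraped from the template become the expected parameter set that `raise_error()` validates against. An illustrative example, assuming the UNRESOLVED_COLUMN.WITH_SUGGESTION template takes `objectName` and `proposal` (as in error-classes.json at the time of writing; message abridged):
```
SELECT raise_error('UNRESOLVED_COLUMN.WITH_SUGGESTION',
                   map('objectName', '`foo`', 'proposal', '`bar`'));
  [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with name `foo` cannot be resolved. ...
```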

def getMessageTemplate(errorClass: String): String = {
val errorClasses = errorClass.split("\\.")
assert(errorClasses.length == 1 || errorClasses.length == 2)
@@ -85,6 +92,17 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
.flatMap(_.sqlState)
.orNull
}

def isValidErrorClass(errorClass: String): Boolean = {
val errorClasses = errorClass.split("\\.")
errorClasses match {
case Array(mainClass) => errorInfoMap.contains(mainClass)
case Array(mainClass, subClass) => errorInfoMap.get(mainClass).map { info =>
info.subClass.get.contains(subClass)
}.getOrElse(false)
case _ => false
}
}
}
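
isValidErrorClass resolves both plain main classes and dotted `MAIN.SUB` names; anything unknown is later reported as USER_RAISED_EXCEPTION_UNKNOWN_ERROR_CLASS. An illustrative sketch (messages abridged):
```
SELECT raise_error('UNSUPPORTED_FEATURE.PURGE_TABLE', map());
  [UNSUPPORTED_FEATURE.PURGE_TABLE] The feature is not supported: ...

SELECT raise_error('NO_SUCH_ERROR_CLASS', map());
  [USER_RAISED_EXCEPTION_UNKNOWN_ERROR_CLASS] The `raise_error()` function was used to raise an unknown error class: 'NO_SUCH_ERROR_CLASS'
```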

private object ErrorClassesJsonReader {
SparkThrowableHelper.scala
@@ -52,14 +52,22 @@ private[spark] object SparkThrowableHelper {
context: String): String = {
val displayMessage = errorReader.getErrorMessage(errorClass, messageParameters)
val displayQueryContext = (if (context.isEmpty) "" else "\n") + context
val prefix = if (errorClass.startsWith("_LEGACY_ERROR_TEMP_")) "" else s"[$errorClass] "
val prefix = if (errorClass.startsWith("_LEGACY_ERROR_")) "" else s"[$errorClass] "
s"$prefix$displayMessage$displayQueryContext"
}

def getSqlState(errorClass: String): String = {
errorReader.getSqlState(errorClass)
}

def isValidErrorClass(errorClass: String): Boolean = {
errorReader.isValidErrorClass(errorClass)
}

def getMessageParameters(errorClass: String): Seq[String] = {
errorReader.getMessageParameters(errorClass)
}

def isInternalError(errorClass: String): Boolean = {
errorClass.startsWith("INTERNAL_ERROR")
}
functions.scala
@@ -3305,6 +3305,14 @@ object functions {
*/
def raise_error(c: Column): Column = Column.fn("raise_error", c)

/**
* Throws an exception with the provided error class and parameter map.
*
* @group misc_funcs
* @since 4.0.0
*/
def raise_error(c: Column, e: Column): Column = Column.fn("raise_error", c, e)

/**
* Returns the estimated number of unique values given the binary representation of a
* Datasketches HllSketch.
Golden explain-results file (assert_true with message)
@@ -1,2 +1,2 @@
- Project [if ((id#0L > cast(0 as bigint))) null else raise_error(id negative!, NullType) AS assert_true((id > 0), id negative!)#0]
+ Project [if ((id#0L > cast(0 as bigint))) null else raise_error(USER_RAISED_EXCEPTION, map(errorMessage, id negative!), NullType) AS assert_true((id > 0), id negative!)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Golden explain-results file (raise_error)
@@ -1,2 +1,2 @@
- Project [raise_error(kaboom, NullType) AS raise_error(kaboom)#0]
+ Project [raise_error(USER_RAISED_EXCEPTION, map(errorMessage, kaboom), NullType) AS raise_error(USER_RAISED_EXCEPTION, map(errorMessage, kaboom))#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
SparkThrowableSuite.scala
@@ -143,7 +143,7 @@ class SparkThrowableSuite extends SparkFunSuite {

test("Message format invariants") {
val messageFormats = errorReader.errorInfoMap
.filterKeys(!_.startsWith("_LEGACY_ERROR_TEMP_"))
.filterKeys(!_.startsWith("_LEGACY_ERROR_"))
.filterKeys(!_.startsWith("INTERNAL_ERROR"))
.values.toSeq.flatMap { i => Seq(i.messageTemplate) }
checkCondition(messageFormats, s => s != null)
@@ -236,7 +236,7 @@
orphans
}

-   val sqlErrorParentDocContent = errors.toSeq.filter(!_._1.startsWith("_LEGACY_ERROR_TEMP_"))
+   val sqlErrorParentDocContent = errors.toSeq.filter(!_._1.startsWith("_LEGACY_ERROR"))
.sortBy(_._1).map(error => {
val name = error._1
val info = error._2
20 changes: 20 additions & 0 deletions docs/sql-error-conditions.md
@@ -2149,6 +2149,26 @@
2. use Java UDF APIs, e.g. `udf(new UDF1[String, Integer] { override def call(s: String): Integer = s.length() }, IntegerType)`, if input types are all non primitive.
3. set "spark.sql.legacy.allowUntypedScalaUDF" to "true" and use this API with caution.

### USER_RAISED_EXCEPTION

SQLSTATE: P0001

`<errorMessage>`

### USER_RAISED_EXCEPTION_PARAMETER_MISMATCH

SQLSTATE: P0001

The `raise_error()` function was used to raise error class: `<errorClass>` which expects parameters: `<expectedParms>`.
The provided parameters `<providedParms>` do not match the expected parameters.
Please make sure to provide all expected parameters.

### USER_RAISED_EXCEPTION_UNKNOWN_ERROR_CLASS

SQLSTATE: P0001

The `raise_error()` function was used to raise an unknown error class: `<errorClass>`

### VARIABLE_ALREADY_EXISTS

[SQLSTATE: 42723](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/test_functions.py
@@ -1031,10 +1031,10 @@ def check_assert_true(self, tpe):
[Row(val=None), Row(val=None), Row(val=None)],
)

with self.assertRaisesRegex(tpe, "too big"):
with self.assertRaisesRegex(tpe, r"\[USER_RAISED_EXCEPTION\] too big"):
df.select(F.assert_true(df.id < 2, "too big")).toDF("val").collect()

with self.assertRaisesRegex(tpe, "2000000"):
with self.assertRaisesRegex(tpe, r"\[USER_RAISED_EXCEPTION\] 2000000.0"):
df.select(F.assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect()

with self.assertRaises(PySparkTypeError) as pe:
misc.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
- import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
+ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator}
+ import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -61,64 +62,88 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
/**
* Throw with the result of an expression (used for debugging).
*/
+ // scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr) - Throws an exception with `expr`.",
usage = "_FUNC_( expr [, errorParams ]) - Throws a USER_RAISED_EXCEPTION with `expr` as message, or a defined error class in `expr` with a parameter map. A `null` errorParms is equivalent to an empty map.",
examples = """
Examples:
> SELECT _FUNC_('custom error message');
-      java.lang.RuntimeException
-      custom error message
+      [USER_RAISED_EXCEPTION] custom error message
+     > SELECT _FUNC_('VIEW_NOT_FOUND', map('relationName', '`V1`'));
+      [VIEW_NOT_FOUND] The view `V1` cannot be found. ...
""",
since = "3.1.0",
group = "misc_funcs")
- case class RaiseError(child: Expression, dataType: DataType)
-   extends UnaryExpression with ImplicitCastInputTypes {
+ // scalastyle:on line.size.limit
+ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: DataType)
+   extends BinaryExpression with ImplicitCastInputTypes {

+   def this(str: Expression) = {
+     this(Literal(
+       if (SQLConf.get.getConf(SQLConf.LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS)) {
+         "_LEGACY_ERROR_USER_RAISED_EXCEPTION"
+       } else {
+         "USER_RAISED_EXCEPTION"
+       }),
+       CreateMap(Seq(Literal("errorMessage"), str)), NullType)
+   }
+
-   def this(child: Expression) = this(child, NullType)
+   def this(errorClass: Expression, errorParms: Expression) = {
+     this(errorClass, errorParms, NullType)
+   }

override def foldable: Boolean = false
override def nullable: Boolean = true
-   override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+   override def inputTypes: Seq[AbstractDataType] =
+     Seq(StringType, MapType(StringType, StringType))
+
+   override def left: Expression = errorClass
+   override def right: Expression = errorParms

override def prettyName: String = "raise_error"

override def eval(input: InternalRow): Any = {
-     val value = child.eval(input)
-     if (value == null) {
-       throw new RuntimeException()
-     }
-     throw new RuntimeException(value.toString)
+     val error = errorClass.eval(input).asInstanceOf[UTF8String]
+     val parms: MapData = errorParms.eval(input).asInstanceOf[MapData]
+     throw raiseError(error, parms)
}

// if (true) is to avoid codegen compilation exception that statement is unreachable
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-     val eval = child.genCode(ctx)
+     val error = errorClass.genCode(ctx)
+     val parms = errorParms.genCode(ctx)
ExprCode(
code = code"""${eval.code}
code = code"""${error.code}
|${parms.code}
|if (true) {
-         |  if (${eval.isNull}) {
-         |    throw new RuntimeException();
-         |  }
-         |  throw new RuntimeException(${eval.value}.toString());
+         |  throw QueryExecutionErrors.raiseError(
+         |    ${error.value},
+         |    ${parms.value});
|}""".stripMargin,
isNull = TrueLiteral,
value = JavaCode.defaultLiteral(dataType)
)
}

-   override protected def withNewChildInternal(newChild: Expression): RaiseError =
-     copy(child = newChild)
+   override protected def withNewChildrenInternal(
+       newLeft: Expression, newRight: Expression): RaiseError = {
+     copy(errorClass = newLeft, errorParms = newRight)
+   }
}

object RaiseError {
-   def apply(child: Expression): RaiseError = new RaiseError(child)
+   def apply(str: Expression): RaiseError = new RaiseError(str)
+
+   def apply(errorClass: Expression, parms: Expression): RaiseError =
+     new RaiseError(errorClass, parms)
}
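
The single-argument constructor above rewrites raise_error(str) into the two-argument form, as the updated golden plan files in this diff show, so the following two calls are equivalent:
```
SELECT raise_error('kaboom');
SELECT raise_error('USER_RAISED_EXCEPTION', map('errorMessage', 'kaboom'));
  [USER_RAISED_EXCEPTION] kaboom
```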

/**
* A function that throws an exception if 'condition' is not true.
*/
@ExpressionDescription(
usage = "_FUNC_(expr) - Throws an exception if `expr` is not true.",
usage = "_FUNC_(expr [, message]) - Throws an exception if `expr` is not true.",
examples = """
Examples:
> SELECT _FUNC_(0 < 1);
QueryExecutionErrors.scala
@@ -21,6 +21,7 @@ import java.io.{File, FileNotFoundException, IOException}
import java.lang.reflect.InvocationTargetException
import java.net.{URISyntaxException, URL}
import java.time.DateTimeException
+ import java.util.Locale
import java.util.concurrent.TimeoutException

import com.fasterxml.jackson.core.{JsonParser, JsonToken}
@@ -41,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval
import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode}
- import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode}
+ import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode, MapData}
import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.expressions.Transform
@@ -2728,4 +2729,50 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
errorClass = "UNSUPPORTED_FEATURE.PURGE_TABLE",
messageParameters = Map.empty)
}

def raiseError(
errorClass: UTF8String,
errorParms: MapData): RuntimeException = {
val errorClassStr = if (errorClass != null) {
errorClass.toString.toUpperCase(Locale.ROOT)
} else {
"null"
}
val errorParmsMap = if (errorParms != null) {
val errorParmsMutable = collection.mutable.Map[String, String]()
errorParms.foreach(StringType, StringType, { case (key, value) =>
errorParmsMutable += (key.toString ->
(if (value == null) { "null" } else { value.toString } ))
})
errorParmsMutable.toMap
} else {
Map.empty[String, String]
}

// Is the error class a known error class? If not raise an error
if (!SparkThrowableHelper.isValidErrorClass(errorClassStr)) {
new SparkRuntimeException(
errorClass = "USER_RAISED_EXCEPTION_UNKNOWN_ERROR_CLASS",
messageParameters = Map("errorClass" -> toSQLValue(errorClassStr)))
} else {
// Did the user provide all parameters? If not raise an error
val expectedParms = SparkThrowableHelper.getMessageParameters(errorClassStr).sorted
val providedParms = errorParmsMap.keys.toSeq.sorted
if (expectedParms != providedParms) {
new SparkRuntimeException(
errorClass = "USER_RAISED_EXCEPTION_PARAMETER_MISMATCH",
messageParameters = Map("errorClass" -> toSQLValue(errorClassStr),
"expectedParms" -> expectedParms.map { p => toSQLValue(p) }.mkString(","),
"providedParms" -> providedParms.map { p => toSQLValue(p) }.mkString(",")))
} else if (errorClassStr == "_LEGACY_ERROR_USER_RAISED_EXCEPTION") {
// Don't break old raise_error() if asked
new RuntimeException(errorParmsMap.head._2)
} else {
// All good, raise the error
new SparkRuntimeException(
errorClass = errorClassStr,
messageParameters = errorParmsMap)
}
}
}
}
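
Since raiseError() upper-cases the supplied class name (toUpperCase(Locale.ROOT)) before validating it, error-class lookup is effectively case-insensitive. A sketch of the expected behavior:
```
SELECT raise_error('view_not_found', map('relationName', '`v1`'));
  [VIEW_NOT_FOUND] The view `v1` cannot be found. ...
```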
SQLConf.scala
@@ -4448,6 +4448,17 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS =
buildConf("spark.sql.legacy.raiseErrorWithoutErrorClass")
.internal()
.doc("When set to true, restores the legacy behavior of `raise_error` and `assert_true` to " +
"not return the `[USER_RAISED_EXCEPTION]` prefix." +
"For example, `raise_error('error!')` returns `error!` instead of " +
"`[USER_RAISED_EXCEPTION] Error!`.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
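
With the flag set, the single-argument form routes through _LEGACY_ERROR_USER_RAISED_EXCEPTION and is rethrown as a plain RuntimeException with no error-class prefix. A sketch of the intended behavior:
```
SET spark.sql.legacy.raiseErrorWithoutErrorClass=true;
SELECT raise_error('Error!');
  Error!
```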

/**
* Holds information about keys that have been deprecated.
*
@@ -5317,6 +5328,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT)
}

def legacyRaiseErrorWithoutErrorClass: Boolean =
getConf(SQLConf.LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */