diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a87ba8a865d36..a9fbe548ba39e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ExtendedAnalysisException @@ -57,10 +57,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Unit]("dataTypeMismatchError") val INVALID_FORMAT_ERROR = TreeNodeTag[Unit]("invalidFormatError") - // Error that has a lower priority, that are not supposed to throw immediately on triggering, - // e.g. certain internal errors. These errors will be thrown at the end of the whole check - // analysis process, if no other error occurs. - var preemptedError: Option[SparkException] = None + // Error that is not supposed to throw immediately on triggering, e.g. certain internal errors. + // The error will be thrown at the end of the whole check analysis process, if no other error + // occurs. 
+ val preemptedError = new PreemptedError() /** * Fails the analysis at the point where a specific tree node was parsed using a provided @@ -119,17 +119,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB private def checkNotContainingLCA(exprs: Seq[Expression], plan: LogicalPlan): Unit = { exprs.foreach(_.transformDownWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { case lcaRef: LateralColumnAliasReference => - // this should be a low priority internal error - if (preemptedError.isEmpty) { - preemptedError = Some( - SparkException.internalError( - "Resolved plan should not contain any " + - s"LateralColumnAliasReference.\nDebugging information: plan:\n$plan", - context = lcaRef.origin.getQueryContext, - summary = lcaRef.origin.context.summary - ) - ) - } + // low priority internal error: preempt it so that any other error is thrown first + preemptedError.set( + SparkException.internalError( + "Resolved plan should not contain any " + + s"LateralColumnAliasReference.\nDebugging information: plan:\n$plan", + context = lcaRef.origin.getQueryContext, + summary = lcaRef.origin.context.summary) + ) lcaRef }) } @@ -187,15 +184,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case e: AnalysisException => throw new ExtendedAnalysisException(e, plan) } - preemptedError = None + preemptedError.clear() try { checkAnalysis0(inlinedPlan) - preemptedError.foreach(throw _) // throw preempted error if any + preemptedError.getErrorOpt().foreach(throw _) // throw preempted error if any } catch { case e: AnalysisException => throw new ExtendedAnalysisException(e, inlinedPlan) } finally { - preemptedError = None + preemptedError.clear() } plan.setAnalyzed() } @@ -1563,3 +1560,31 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB } } } + +// A holder of the preempted error that only keeps the top priority element, representing the sole +// error to be thrown at the end of the whole 
check analysis process, if no other error occurs.
+class PreemptedError() {
+  // An error together with its priority; the error with the larger priority value wins.
+  case class ErrorWithPriority(error: Exception with SparkThrowable, priority: Int)
+
+  private var errorOpt: Option[ErrorWithPriority] = None
+
+  // Records `error` as the preempted error if nothing is preempted yet, or if its priority is
+  // strictly higher than the recorded one (on a tie the earlier error is kept). Without an
+  // explicit `priority`, internal errors (error class starting with "INTERNAL_ERROR") get the
+  // lowest priority (1) and all other errors, including those without an error class, get 2.
+  def set(error: Exception with SparkThrowable, priority: Option[Int] = None): Unit = {
+    val calculatedPriority = priority.getOrElse {
+      // getErrorClass may be null (no error class set); Option(...) avoids an NPE in that case.
+      if (Option(error.getErrorClass).exists(_.startsWith("INTERNAL_ERROR"))) 1 else 2
+    }
+    if (errorOpt.forall(calculatedPriority > _.priority)) {
+      errorOpt = Some(ErrorWithPriority(error, calculatedPriority))
+    }
+  }
+
+  // Returns the currently preempted error, or None if nothing is preempted.
+  def getErrorOpt(): Option[Exception with SparkThrowable] = errorOpt.map(_.error)
+
+  def clear(): Unit = errorOpt = None
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 62856a96f7ee8..a44b52d0bbc28 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.connector.catalog.InMemoryTable
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -1807,4 +1808,26 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val plan
= testRelation.select(udf.as("u")).select($"u").analyze
     assert(plan.output.head.nullable)
   }
+
+  test("test methods of PreemptedError") {
+    val holder = new PreemptedError()
+    assert(holder.getErrorOpt().isEmpty)
+
+    // the first error offered is always accepted, even a lowest-priority internal one
+    val lowPriority = SparkException.internalError("some internal error to be preempted")
+    holder.set(lowPriority)
+    assert(holder.getErrorOpt().contains(lowPriority))
+
+    // a non-internal error outranks the internal one and takes its place ...
+    val highPriority = QueryCompilationErrors.unresolvedColumnError("name", Seq("a"))
+      .asInstanceOf[AnalysisException]
+    holder.set(highPriority)
+    assert(holder.getErrorOpt().contains(highPriority))
+    // ... while offering the internal error again afterwards leaves it untouched
+    holder.set(lowPriority)
+    assert(holder.getErrorOpt().contains(highPriority))
+
+    holder.clear()
+    assert(holder.getErrorOpt().isEmpty)
+  }
 }