diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 02d6778c08e6f..2426a8b4a9062 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 /**
  * Abstract class all optimizers should inherit of, contains the standard batches (extending
@@ -37,6 +38,12 @@ import org.apache.spark.sql.types._
 abstract class Optimizer(sessionCatalog: SessionCatalog)
   extends RuleExecutor[LogicalPlan] {
 
+  // Check for structural integrity of the plan in test mode. Currently we only check if a plan is
+  // still resolved after the execution of each rule.
+  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
+    !Utils.isTesting || plan.resolved
+  }
+
   protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
 
   def batches: Seq[Batch] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 0e89d1c8f31e8..7e4b784033bfc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -63,6 +63,14 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
   /** Defines a sequence of rule batches, to be overridden by the implementation. */
   protected def batches: Seq[Batch]
 
+  /**
+   * Defines a check function that checks for structural integrity of the plan after the execution
+   * of each rule. For example, we can check whether a plan is still resolved after each rule in
+   * `Optimizer`, so we can catch rules that return invalid plans. The check function returns
+   * `false` if the given plan doesn't pass the structural integrity check.
+   */
+  protected def isPlanIntegral(plan: TreeType): Boolean = true
+
   /**
    * Executes the batches of rules defined by the subclass. The batches are executed serially
    * using the defined execution strategy. Within each batch, rules are also executed serially.
@@ -93,6 +101,13 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
                """.stripMargin)
             }
 
+            // Run the structural integrity checker against the plan after each rule.
+            if (!isPlanIntegral(result)) {
+              val message = s"After applying rule ${rule.ruleName} in batch ${batch.name}, " +
+                "the structural integrity of the plan is broken."
+              throw new TreeNodeException(result, message, null)
+            }
+
             result
         }
         iteration += 1
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 0496d611ec3c7..b4c8eab19c5cc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -25,7 +25,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
+import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
@@ -188,7 +188,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
       expected: Any,
       inputRow: InternalRow = EmptyRow): Unit = {
     val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation())
-    val optimizedPlan = SimpleTestOptimizer.execute(plan)
+    // We should analyze the plan first, otherwise we may optimize an unresolved plan.
+    val analyzedPlan = SimpleAnalyzer.execute(plan)
+    val optimizedPlan = SimpleTestOptimizer.execute(analyzedPlan)
     checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
new file mode 100644
index 0000000000000..6e183d81b7265
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+
+
+class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
+
+  object OptimizeRuleBreakSI extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      case Project(projectList, child) =>
+        val newAttr = UnresolvedAttribute("unresolvedAttr")
+        Project(projectList ++ Seq(newAttr), child)
+    }
+  }
+
+  object Optimize extends Optimizer(
+    new SessionCatalog(
+      new InMemoryCatalog,
+      EmptyFunctionRegistry,
+      new SQLConf())) {
+    val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
+    override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches
+  }
+
+  test("check for invalid plan after execution of rule") {
+    val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze
+    assert(analyzed.resolved)
+    val message = intercept[TreeNodeException[LogicalPlan]] {
+      Optimize.execute(analyzed)
+    }.getMessage
+    val ruleName = OptimizeRuleBreakSI.ruleName
+    assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
+    assert(message.contains("the structural integrity of the plan is broken"))
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
index c9d36910b0998..a67f54b263cc9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
@@ -56,4 +56,21 @@ class RuleExecutorSuite extends SparkFunSuite {
     }.getMessage
     assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
   }
+
+  test("structural integrity checker") {
+    object WithSIChecker extends RuleExecutor[Expression] {
+      override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
+        case IntegerLiteral(_) => true
+        case _ => false
+      }
+      val batches = Batch("once", Once, DecrementLiterals) :: Nil
+    }
+
+    assert(WithSIChecker.execute(Literal(10)) === Literal(9))
+
+    val message = intercept[TreeNodeException[LogicalPlan]] {
+      WithSIChecker.execute(Literal(10.1))
+    }.getMessage
+    assert(message.contains("the structural integrity of the plan is broken"))
+  }
 }
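A note on the `Optimizer` override above: the check is deliberately gated to test runs. `Utils.isTesting` is true when the `SPARK_TESTING` environment variable or the `spark.testing` system property is set, so `!Utils.isTesting || plan.resolved` short-circuits to `true` in production, where the check can neither cost anything nor fail, and reduces to `plan.resolved` in test mode.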
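For illustration, here is a minimal, self-contained sketch of a custom `RuleExecutor` opting in to the new `isPlanIntegral` hook, in the spirit of the `WithSIChecker` test above. `IntegerOnlyExecutor` and its `ToDouble` rule are names invented for this sketch, not part of the patch:

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

// Hypothetical executor whose structural invariant is "the plan is a bare
// integer literal".
object IntegerOnlyExecutor extends RuleExecutor[Expression] {
  override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
    case IntegerLiteral(_) => true
    case _ => false
  }

  // Hypothetical rule that rewrites integer literals to doubles, deliberately
  // breaking the invariant above.
  object ToDouble extends Rule[Expression] {
    def apply(expr: Expression): Expression = expr transform {
      case IntegerLiteral(i) => Literal(i.toDouble)
    }
  }

  val batches = Batch("demo", Once, ToDouble) :: Nil
}

// IntegerOnlyExecutor.execute(Literal(1)) throws a TreeNodeException: after
// ToDouble runs, the plan is Literal(1.0), which isPlanIntegral rejects.
```

Because the default `isPlanIntegral` returns `true`, existing `RuleExecutor` subclasses that do not override it are unaffected.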
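The `ExpressionEvalHelper` change follows from the same invariant: `SimpleTestOptimizer` is an `Optimizer`, so under `spark.testing` it now rejects unresolved input as soon as any rule fires. Below is a sketch of the order dependency, assuming the catalyst test helpers are on the classpath; `Add(Literal(1), Literal(1.5))` is just one arbitrary expression that stays unresolved until type coercion runs:

```scala
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Literal}
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}

// Add(int, double) is unresolved until the analyzer inserts a cast, so
// feeding `plan` straight to the optimizer would trip the checker in test mode.
val plan = Project(Alias(Add(Literal(1), Literal(1.5)), "sum")() :: Nil, OneRowRelation())
assert(!plan.resolved)

// Analyze first (type coercion casts the int operand to double), then
// optimize; the optimizer can now constant-fold the expression safely.
val analyzed = SimpleAnalyzer.execute(plan)
val optimized = SimpleTestOptimizer.execute(analyzed)
```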