From b897c600ac197ead2fc286ad0e1c487d34634bb1 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 14 Jun 2017 22:11:41 -0700
Subject: [PATCH] [SPARK-21092][SQL] Wire SQLConf in logical plan and expressions

## What changes were proposed in this pull request?

It is really painful to not have configs in logical plan and expressions. We had to add all sorts of hacks (e.g. pass SQLConf explicitly in functions). This patch exposes SQLConf in logical plan, using a thread local variable and a getter closure that's set once there is an active SparkSession.

The implementation is a bit of a hack, since we didn't anticipate this need in the beginning (config was only exposed in physical plan). The implementation is described in `SQLConf.get`.

In terms of future work, we should follow up to clean up CBO (remove the need for passing in config).

## How was this patch tested?

Updated relevant tests for constraint propagation.

Author: Reynold Xin

Closes #18299 from rxin/SPARK-21092.
---
 .../sql/catalyst/optimizer/Optimizer.scala    | 25 +++++------
 .../spark/sql/catalyst/optimizer/joins.scala  |  5 +--
 .../spark/sql/catalyst/plans/QueryPlan.scala  |  3 ++
 .../catalyst/plans/QueryPlanConstraints.scala | 33 +++++----------
 .../apache/spark/sql/internal/SQLConf.scala   | 42 +++++++++++++++++++
 .../BinaryComparisonSimplificationSuite.scala |  2 +-
 .../BooleanSimplificationSuite.scala          |  2 +-
 .../InferFiltersFromConstraintsSuite.scala    | 24 +++++------
 .../optimizer/OuterJoinEliminationSuite.scala | 37 ++++++++--------
 .../PropagateEmptyRelationSuite.scala         |  4 +-
 .../optimizer/PruneFiltersSuite.scala         | 36 +++++++---------
 .../optimizer/SetOperationSuite.scala         |  2 +-
 .../plans/ConstraintPropagationSuite.scala    | 29 ++++++++-----
 .../org/apache/spark/sql/SparkSession.scala   |  5 +++
 14 files changed, 141 insertions(+), 108 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d16689a34298a..3ab70fb90470c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -77,12 +77,12 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
       // Operator push down
       PushProjectionThroughUnion,
       ReorderJoin(conf),
-      EliminateOuterJoin(conf),
+      EliminateOuterJoin,
       PushPredicateThroughJoin,
       PushDownPredicate,
       LimitPushDown(conf),
       ColumnPruning,
-      InferFiltersFromConstraints(conf),
+      InferFiltersFromConstraints,
       // Operator combine
       CollapseRepartition,
       CollapseProject,
@@ -102,7 +102,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
       SimplifyConditionals,
       RemoveDispensableExpressions,
       SimplifyBinaryComparison,
-      PruneFilters(conf),
+      PruneFilters,
       EliminateSorts,
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
@@ -619,14 +619,15 @@ object CollapseWindow extends Rule[LogicalPlan] {
  * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and
  * LeftSemi joins.
  */
-case class InferFiltersFromConstraints(conf: SQLConf)
-  extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) {
-    inferFilters(plan)
-  } else {
-    plan
-  }
+object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (SQLConf.get.constraintPropagationEnabled) {
+      inferFilters(plan)
+    } else {
+      plan
+    }
+  }
 
   private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform {
     case filter @ Filter(condition, child) =>
@@ -717,7 +718,7 @@ object EliminateSorts extends Rule[LogicalPlan] {
  * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`.
  * 3) by eliminating the always-true conditions given the constraints on the child's output.
  */
-case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // If the filter condition always evaluate to true, remove the filter.
     case Filter(Literal(true, BooleanType), child) => child
@@ -730,7 +731,7 @@ case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateH
     case f @ Filter(fc, p: LogicalPlan) =>
       val (prunedPredicates, remainingPredicates) =
         splitConjunctivePredicates(fc).partition { cond =>
-          cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond)
+          cond.deterministic && p.constraints.contains(cond)
         }
       if (prunedPredicates.isEmpty) {
         f
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 2fe3039774423..bb97e2c808b9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -113,7 +113,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe
  *
  * This rule should be executed before pushing down the Filter
  */
-case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
   /**
    * Returns whether the expression returns null or false when all inputs are nulls.
@@ -129,8 +129,7 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred
   }
 
   private def buildNewJoinType(filter: Filter, join: Join): JoinType = {
-    val conditions = splitConjunctivePredicates(filter.condition) ++
-      filter.getConstraints(conf.constraintPropagationEnabled)
+    val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints
     val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet))
     val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 8bc462e1e72c9..9130b14763e24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 
 abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
@@ -27,6 +28,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
 
   self: PlanType =>
 
+  def conf: SQLConf = SQLConf.get
+
   def output: Seq[Attribute]
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala
index 7d8a17d97759c..b08a009f0dca1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala
@@ -27,18 +27,20 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl
    * example, if this set contains the expression `a = 2` then that expression is guaranteed to
    * evaluate to `true` for all rows produced.
    */
-  lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
-
-  /**
-   * Returns [[constraints]] depending on the config of enabling constraint propagation. If the
-   * flag is disabled, simply returning an empty constraints.
-   */
-  def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet =
-    if (constraintPropagationEnabled) {
-      constraints
+  lazy val constraints: ExpressionSet = {
+    if (conf.constraintPropagationEnabled) {
+      ExpressionSet(
+        validConstraints
+          .union(inferAdditionalConstraints(validConstraints))
+          .union(constructIsNotNullConstraints(validConstraints))
+          .filter { c =>
+            c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
+          }
+      )
     } else {
       ExpressionSet(Set.empty)
     }
+  }
 
   /**
    * This method can be overridden by any child class of QueryPlan to specify a set of constraints
@@ -50,19 +52,6 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl
    */
   protected def validConstraints: Set[Expression] = Set.empty
 
-  /**
-   * Extracts the relevant constraints from a given set of constraints based on the attributes that
-   * appear in the [[outputSet]].
-   */
-  protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
-    constraints
-      .union(inferAdditionalConstraints(constraints))
-      .union(constructIsNotNullConstraints(constraints))
-      .filter(constraint =>
-        constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) &&
-          constraint.deterministic)
-  }
-
   /**
    * Infers a set of `isNotNull` constraints from null intolerant expressions as well as
    * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9f7c760fb9d21..6ab3a615e6cc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.internal
 
 import java.util.{Locale, NoSuchElementException, Properties, TimeZone}
 import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
 
 import scala.collection.JavaConverters._
 import scala.collection.immutable
@@ -64,6 +65,47 @@ object SQLConf {
     }
   }
 
+  /**
+   * Default config. Only used when there is no active SparkSession for the thread.
+   * See [[get]] for more information.
+   */
+  private val fallbackConf = new ThreadLocal[SQLConf] {
+    override def initialValue: SQLConf = new SQLConf
+  }
+
+  /** See [[get]] for more information. */
+  def getFallbackConf: SQLConf = fallbackConf.get()
+
+  /**
+   * Defines a getter that returns the SQLConf within scope.
+   * See [[get]] for more information.
+   */
+  private val confGetter = new AtomicReference[() => SQLConf](() => fallbackConf.get())
+
+  /**
+   * Sets the active config object within the current scope.
+   * See [[get]] for more information.
+   */
+  def setSQLConfGetter(getter: () => SQLConf): Unit = {
+    confGetter.set(getter)
+  }
+
+  /**
+   * Returns the active config object within the current scope. If there is an active SparkSession,
+   * the proper SQLConf associated with the thread's session is used.
+   *
+   * The way this works is a little bit convoluted, due to the fact that config was added initially
+   * only for physical plans (and as a result not in sql/catalyst module).
+   *
+   * The first time a SparkSession is instantiated, we set the [[confGetter]] to return the
+   * active SparkSession's config. If there is no active SparkSession, it returns using the thread
+   * local [[fallbackConf]]. The reason [[fallbackConf]] is a thread local (rather than just a conf)
+   * is to support setting different config options for different threads so we can potentially
+   * run tests in parallel. At the time this feature was implemented, this was a no-op since we
+   * run unit tests (that does not involve SparkSession) in serial order.
+   */
+  def get: SQLConf = confGetter.get()()
+
   val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
     .internal()
     .doc("The max number of iterations the optimizer and analyzer runs.")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
index b29e1cbd14943..2a04bd588dc1d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
@@ -37,7 +37,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
         ConstantFolding,
         BooleanSimplification,
         SimplifyBinaryComparison,
-        PruneFilters(conf)) :: Nil
+        PruneFilters) :: Nil
   }
 
   val nullableRelation = LocalRelation('a.int.withNullability(true))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index c275f997ba6e9..1df0a89cf0bf1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -38,7 +38,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
         NullPropagation(conf),
         ConstantFolding,
         BooleanSimplification,
-        PruneFilters(conf)) :: Nil
+        PruneFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 9a4bcdb011435..cdc9f25cf8777 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
+import org.apache.spark.sql.internal.SQLConf
 
 class InferFiltersFromConstraintsSuite extends PlanTest {
 
@@ -32,20 +32,11 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
       Batch("InferAndPushDownFilters", FixedPoint(100),
         PushPredicateThroughJoin,
         PushDownPredicate,
-        InferFiltersFromConstraints(conf),
+        InferFiltersFromConstraints,
         CombineFilters,
         BooleanSimplification) :: Nil
   }
 
-  object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("InferAndPushDownFilters", FixedPoint(100),
-        PushPredicateThroughJoin,
-        PushDownPredicate,
-        InferFiltersFromConstraints(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)),
-        CombineFilters) :: Nil
-  }
-
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
 
   test("filter: filter out constraints in condition") {
@@ -215,8 +206,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
   }
 
   test("No inferred filter when constraint propagation is disabled") {
-    val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
-    val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery)
-    comparePlans(optimized, originalQuery)
+    try {
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+      val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    } finally {
+      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+    }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
index b7136703b7541..a37bc4bca2422 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
+import org.apache.spark.sql.internal.SQLConf
 
 class OuterJoinEliminationSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
@@ -32,16 +32,7 @@ class OuterJoinEliminationSuite extends PlanTest {
       Batch("Subqueries", Once,
         EliminateSubqueryAliases) ::
       Batch("Outer Join Elimination", Once,
-        EliminateOuterJoin(conf),
-        PushPredicateThroughJoin) :: Nil
-  }
-
-  object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("Subqueries", Once,
-        EliminateSubqueryAliases) ::
-      Batch("Outer Join Elimination", Once,
-        EliminateOuterJoin(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)),
+        EliminateOuterJoin,
         PushPredicateThroughJoin) :: Nil
   }
 
@@ -243,19 +234,25 @@ class OuterJoinEliminationSuite extends PlanTest {
   }
 
   test("no outer join elimination if constraint propagation is disabled") {
-    val x = testRelation.subquery('x)
-    val y = testRelation1.subquery('y)
-
-    // The predicate "x.b + y.d >= 3" will be inferred constraints like:
-    // "x.b != null" and "y.d != null", if constraint propagation is enabled.
-    // When we disable it, the predicate can't be evaluated on left or right plan and used to
-    // filter out nulls. So the Outer Join will not be eliminated.
-    val originalQuery =
+    try {
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+
+      val x = testRelation.subquery('x)
+      val y = testRelation1.subquery('y)
+
+      // The predicate "x.b + y.d >= 3" will be inferred constraints like:
+      // "x.b != null" and "y.d != null", if constraint propagation is enabled.
+      // When we disable it, the predicate can't be evaluated on left or right plan and used to
+      // filter out nulls. So the Outer Join will not be eliminated.
+      val originalQuery =
         x.join(y, FullOuter, Option("x.a".attr === "y.d".attr))
           .where("x.b".attr + "y.d".attr >= 3)
 
-    val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery.analyze)
+      val optimized = Optimize.execute(originalQuery.analyze)
 
-    comparePlans(optimized, originalQuery.analyze)
+      comparePlans(optimized, originalQuery.analyze)
+    } finally {
+      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+    }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index 38dff4733f714..2285be16938d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -33,7 +33,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
         ReplaceExceptWithAntiJoin,
         ReplaceIntersectWithSemiJoin,
         PushDownPredicate,
-        PruneFilters(conf),
+        PruneFilters,
         PropagateEmptyRelation) :: Nil
   }
 
@@ -45,7 +45,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
         ReplaceExceptWithAntiJoin,
         ReplaceIntersectWithSemiJoin,
         PushDownPredicate,
-        PruneFilters(conf)) :: Nil
+        PruneFilters) :: Nil
   }
 
   val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
index 741dd0cf428d0..706634cdd29b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
 
 class PruneFiltersSuite extends PlanTest {
@@ -34,18 +35,7 @@ class PruneFiltersSuite extends PlanTest {
         EliminateSubqueryAliases) ::
       Batch("Filter Pushdown and Pruning", Once,
         CombineFilters,
-        PruneFilters(conf),
-        PushDownPredicate,
-        PushPredicateThroughJoin) :: Nil
-  }
-
-  object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("Subqueries", Once,
-        EliminateSubqueryAliases) ::
-      Batch("Filter Pushdown and Pruning", Once,
-        CombineFilters,
-        PruneFilters(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)),
+        PruneFilters,
         PushDownPredicate,
         PushPredicateThroughJoin) :: Nil
   }
@@ -159,15 +149,19 @@ class PruneFiltersSuite extends PlanTest {
         ("tr1.a".attr > 10 || "tr1.c".attr < 10) &&
           'd.attr < 100)
 
-    val optimized =
-      OptimizeWithConstraintPropagationDisabled.execute(queryWithUselessFilter.analyze)
-    // When constraint propagation is disabled, the useless filter won't be pruned.
-    // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant
-    // and duplicate filters.
-    val correctAnswer = tr1
-      .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10)
-      .join(tr2.where('d.attr < 100).where('d.attr < 100),
+    SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    try {
+      val optimized = Optimize.execute(queryWithUselessFilter.analyze)
+      // When constraint propagation is disabled, the useless filter won't be pruned.
+      // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant
+      // and duplicate filters.
+      val correctAnswer = tr1
+        .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10)
+        .join(tr2.where('d.attr < 100).where('d.attr < 100),
         Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze
-    comparePlans(optimized, correctAnswer)
+      comparePlans(optimized, correctAnswer)
+    } finally {
+      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+    }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index 756e0f35b2178..21b7f49e14bd5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -34,7 +34,7 @@ class SetOperationSuite extends PlanTest {
         CombineUnions,
         PushProjectionThroughUnion,
         PushDownPredicate,
-        PruneFilters(conf)) :: Nil
+        PruneFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 4061394b862a6..a3948d90b0e4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}
 
 class ConstraintPropagationSuite extends SparkFunSuite {
@@ -399,20 +400,26 @@ class ConstraintPropagationSuite extends SparkFunSuite {
   }
 
   test("enable/disable constraint propagation") {
-    val tr = LocalRelation('a.int, 'b.string, 'c.int)
-    val filterRelation = tr.where('a.attr > 10)
-
-    verifyConstraints(
-      filterRelation.analyze.getConstraints(constraintPropagationEnabled = true),
-      filterRelation.analyze.constraints)
-
-    assert(filterRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty)
-
-    val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
-      .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)
-
-    verifyConstraints(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = true),
-      aliasedRelation.analyze.constraints)
-    assert(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty)
+    try {
+      val tr = LocalRelation('a.int, 'b.string, 'c.int)
+      val filterRelation = tr.where('a.attr > 10)
+
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+      assert(filterRelation.analyze.constraints.nonEmpty)
+
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+      assert(filterRelation.analyze.constraints.isEmpty)
+
+      val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
+        .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)
+
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+      assert(aliasedRelation.analyze.constraints.nonEmpty)
+
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+      assert(aliasedRelation.analyze.constraints.isEmpty)
+    } finally {
+      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index d2bf350711936..2c38f7d7c88da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -87,6 +87,11 @@ class SparkSession private(
 
   sparkContext.assertNotStopped()
 
+  // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's.
+  SQLConf.setSQLConfGetter(() => {
+    SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf)
+  })
+
   /**
    * The version of Spark on which this application is running.
   *
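
For readers who want to see the getter-closure pattern this patch wires up in isolation, here is a minimal, self-contained Scala sketch of the same idea: an `AtomicReference` holds a closure that resolves the "current" config, the closure defaults to a per-thread fallback, and a session-like owner can later install its own lookup. The `Settings`, `setGetter`, and `Demo` names are hypothetical stand-ins used only for illustration, not the actual `SQLConf`/`SparkSession` API.

```scala
import java.util.concurrent.atomic.AtomicReference

// Stand-in for SQLConf: a trivial mutable settings holder.
class Settings {
  @volatile var constraintPropagationEnabled: Boolean = true
}

object Settings {
  // Per-thread fallback, used while no session has registered a getter
  // (mirrors SQLConf's fallbackConf being a ThreadLocal).
  private val fallback = new ThreadLocal[Settings] {
    override def initialValue(): Settings = new Settings
  }

  // Getter closure; starts out pointing at the thread-local fallback
  // (mirrors SQLConf's confGetter).
  private val getter = new AtomicReference[() => Settings](() => fallback.get())

  // A session installs its own conf lookup here, the way SparkSession
  // calls SQLConf.setSQLConfGetter in this patch.
  def setGetter(g: () => Settings): Unit = getter.set(g)

  // What a logical plan or expression calls instead of having a conf
  // threaded through constructor parameters.
  def get: Settings = getter.get()()
}

object Demo extends App {
  // Before any session exists, the thread-local fallback is returned.
  println(Settings.get.constraintPropagationEnabled) // true

  // A session-like owner wires in its own conf; later reads see it.
  val sessionConf = new Settings
  sessionConf.constraintPropagationEnabled = false
  Settings.setGetter(() => sessionConf)
  println(Settings.get.constraintPropagationEnabled) // false
}
```

Keeping the fallback in a `ThreadLocal` is what lets tests that never create a SparkSession still flip options per thread, which is the rationale the patch gives in the `SQLConf.get` Scaladoc.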