[SPARK-21092][SQL] Wire SQLConf in logical plan and expressions
## What changes were proposed in this pull request?
It is really painful not to have configs available in logical plans and expressions. We had to add all sorts of hacks (e.g. passing SQLConf explicitly into functions). This patch exposes SQLConf to the logical plan, using a thread-local variable and a getter closure that is set once there is an active SparkSession.

The implementation is a bit of a hack, since we didn't anticipate this need in the beginning (config was originally exposed only in the physical plan). The mechanism is described in the scaladoc of `SQLConf.get`.
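The resulting pattern is sketched below with a hypothetical rule name; the real instances are the rewritten `InferFiltersFromConstraints`, `PruneFilters`, and `EliminateOuterJoin` in the diff:

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

// Hypothetical rule, for illustration only. Instead of carrying a SQLConf
// constructor parameter, the rule is a plain object that consults the
// active session's conf at the point where the rule actually runs.
object MyConfAwareRule extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = {
    if (SQLConf.get.constraintPropagationEnabled) {
      plan.transform { case p => p } // placeholder for the real rewrite
    } else {
      plan // feature disabled for this session: leave the plan untouched
    }
  }
}
```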

In terms of future work, we should follow up to clean up CBO (remove the need for passing in config).

## How was this patch tested?
Updated relevant tests for constraint propagation.
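Concretely, the updated suites drop their separate `OptimizeWithConstraintPropagationDisabled` executors and instead flip the flag on the active conf, restoring it in a `finally` block. A distilled sketch of that pattern (the helper name is hypothetical):

```scala
import org.apache.spark.sql.internal.SQLConf

// Flip the flag on the thread's active conf, run the body, and always
// restore the default so later tests on the same thread are unaffected.
def withConstraintPropagationDisabled(body: => Unit): Unit = {
  try {
    SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
    body
  } finally {
    SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
  }
}
```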

Author: Reynold Xin <rxin@databricks.com>

Closes apache#18299 from rxin/SPARK-21092.
rxin authored and wangzejie committed Jun 16, 2017
1 parent 3c1f793 commit b897c60
Showing 14 changed files with 141 additions and 108 deletions.
@@ -77,12 +77,12 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
// Operator push down
PushProjectionThroughUnion,
ReorderJoin(conf),
-     EliminateOuterJoin(conf),
+     EliminateOuterJoin,
PushPredicateThroughJoin,
PushDownPredicate,
LimitPushDown(conf),
ColumnPruning,
-     InferFiltersFromConstraints(conf),
+     InferFiltersFromConstraints,
// Operator combine
CollapseRepartition,
CollapseProject,
@@ -102,7 +102,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
SimplifyConditionals,
RemoveDispensableExpressions,
SimplifyBinaryComparison,
-     PruneFilters(conf),
+     PruneFilters,
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
@@ -619,14 +619,15 @@ object CollapseWindow extends Rule[LogicalPlan] {
* Note: While this optimization is applicable to all types of join, it primarily benefits Inner and
* LeftSemi joins.
*/
- case class InferFiltersFromConstraints(conf: SQLConf)
-   extends Rule[LogicalPlan] with PredicateHelper {
-   def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) {
-     inferFilters(plan)
-   } else {
-     plan
-   }
+ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
+
+   def apply(plan: LogicalPlan): LogicalPlan = {
+     if (SQLConf.get.constraintPropagationEnabled) {
+       inferFilters(plan)
+     } else {
+       plan
+     }
+   }

private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition, child) =>
@@ -717,7 +718,7 @@ object EliminateSorts extends Rule[LogicalPlan] {
* 2) by substituting a dummy empty relation when the filter will always evaluate to `false`.
* 3) by eliminating the always-true conditions given the constraints on the child's output.
*/
- case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// If the filter condition always evaluate to true, remove the filter.
case Filter(Literal(true, BooleanType), child) => child
@@ -730,7 +731,7 @@ case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateH
case f @ Filter(fc, p: LogicalPlan) =>
val (prunedPredicates, remainingPredicates) =
splitConjunctivePredicates(fc).partition { cond =>
-         cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond)
+         cond.deterministic && p.constraints.contains(cond)
}
if (prunedPredicates.isEmpty) {
f
@@ -113,7 +113,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe
*
* This rule should be executed before pushing down the Filter
*/
- case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {

/**
* Returns whether the expression returns null or false when all inputs are nulls.
@@ -129,8 +129,7 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred
}

private def buildNewJoinType(filter: Filter, join: Join): JoinType = {
-     val conditions = splitConjunctivePredicates(filter.condition) ++
-       filter.getConstraints(conf.constraintPropagationEnabled)
+     val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints
val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet))
val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet))

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreeNode
+ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}

abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
@@ -27,6 +28,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]

self: PlanType =>

+   def conf: SQLConf = SQLConf.get

def output: Seq[Attribute]

/**
@@ -27,18 +27,20 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
* evaluate to `true` for all rows produced.
*/
-   lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
-
-   /**
-    * Returns [[constraints]] depending on the config of enabling constraint propagation. If the
-    * flag is disabled, simply returning an empty constraints.
-    */
-   def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet =
-     if (constraintPropagationEnabled) {
-       constraints
+   lazy val constraints: ExpressionSet = {
+     if (conf.constraintPropagationEnabled) {
+       ExpressionSet(
+         validConstraints
+           .union(inferAdditionalConstraints(validConstraints))
+           .union(constructIsNotNullConstraints(validConstraints))
+           .filter { c =>
+             c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
+           }
+       )
     } else {
       ExpressionSet(Set.empty)
     }
+   }

/**
* This method can be overridden by any child class of QueryPlan to specify a set of constraints
@@ -50,19 +52,6 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl
*/
protected def validConstraints: Set[Expression] = Set.empty

-   /**
-    * Extracts the relevant constraints from a given set of constraints based on the attributes that
-    * appear in the [[outputSet]].
-    */
-   protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
-     constraints
-       .union(inferAdditionalConstraints(constraints))
-       .union(constructIsNotNullConstraints(constraints))
-       .filter(constraint =>
-         constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) &&
-         constraint.deterministic)
-   }

/**
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
* non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
@@ -19,6 +19,7 @@ package org.apache.spark.sql.internal

import java.util.{Locale, NoSuchElementException, Properties, TimeZone}
import java.util.concurrent.TimeUnit
+ import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._
import scala.collection.immutable
@@ -64,6 +65,47 @@ object SQLConf {
}
}

+ /**
+  * Default config. Only used when there is no active SparkSession for the thread.
+  * See [[get]] for more information.
+  */
+ private val fallbackConf = new ThreadLocal[SQLConf] {
+   override def initialValue: SQLConf = new SQLConf
+ }
+
+ /** See [[get]] for more information. */
+ def getFallbackConf: SQLConf = fallbackConf.get()
+
+ /**
+  * Defines a getter that returns the SQLConf within scope.
+  * See [[get]] for more information.
+  */
+ private val confGetter = new AtomicReference[() => SQLConf](() => fallbackConf.get())
+
+ /**
+  * Sets the active config object within the current scope.
+  * See [[get]] for more information.
+  */
+ def setSQLConfGetter(getter: () => SQLConf): Unit = {
+   confGetter.set(getter)
+ }
+
+ /**
+  * Returns the active config object within the current scope. If there is an active SparkSession,
+  * the proper SQLConf associated with the thread's session is used.
+  *
+  * The way this works is a little bit convoluted, due to the fact that config was added initially
+  * only for physical plans (and as a result not in the sql/catalyst module).
+  *
+  * The first time a SparkSession is instantiated, we set the [[confGetter]] to return the
+  * active SparkSession's config. If there is no active SparkSession, it falls back to the
+  * thread-local [[fallbackConf]]. The reason [[fallbackConf]] is a thread local (rather than just
+  * a conf) is to support setting different config options for different threads, so we can
+  * potentially run tests in parallel. At the time this feature was implemented, this was a no-op
+  * since unit tests that do not involve a SparkSession run serially.
+  */
+ def get: SQLConf = confGetter.get()()

val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
.internal()
.doc("The max number of iterations the optimizer and analyzer runs.")
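The hunk above adds only the catalyst-side getter. A sketch of how a session might install it via `setSQLConfGetter`, per the scaladoc for `get`; the wiring below is illustrative, not the literal SparkSession change (which is in one of the files not shown in this excerpt):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

// Illustrative wiring (in Spark this runs inside SparkSession itself, where
// sessionState is visible): prefer the conf of the thread's active session,
// otherwise fall back to the thread-local default conf.
SQLConf.setSQLConfGetter(() =>
  SparkSession.getActiveSession
    .map(_.sessionState.conf)           // conf of the active session, if any
    .getOrElse(SQLConf.getFallbackConf) // thread-local default otherwise
)
```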
@@ -37,7 +37,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
ConstantFolding,
BooleanSimplification,
SimplifyBinaryComparison,
-       PruneFilters(conf)) :: Nil
+       PruneFilters) :: Nil
}

val nullableRelation = LocalRelation('a.int.withNullability(true))
@@ -38,7 +38,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
NullPropagation(conf),
ConstantFolding,
BooleanSimplification,
-       PruneFilters(conf)) :: Nil
+       PruneFilters) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
- import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
+ import org.apache.spark.sql.internal.SQLConf

class InferFiltersFromConstraintsSuite extends PlanTest {

@@ -32,20 +32,11 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
Batch("InferAndPushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushDownPredicate,
-       InferFiltersFromConstraints(conf),
+       InferFiltersFromConstraints,
CombineFilters,
BooleanSimplification) :: Nil
}

-   object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
-     val batches =
-       Batch("InferAndPushDownFilters", FixedPoint(100),
-         PushPredicateThroughJoin,
-         PushDownPredicate,
-         InferFiltersFromConstraints(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)),
-         CombineFilters) :: Nil
-   }

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

test("filter: filter out constraints in condition") {
@@ -215,8 +206,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
}

test("No inferred filter when constraint propagation is disabled") {
-     val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
-     val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery)
-     comparePlans(optimized, originalQuery)
+     try {
+       SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+       val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
+       val optimized = Optimize.execute(originalQuery)
+       comparePlans(optimized, originalQuery)
+     } finally {
+       SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+     }
}
}
@@ -24,24 +24,15 @@ import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
- import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
+ import org.apache.spark.sql.internal.SQLConf

class OuterJoinEliminationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
Batch("Outer Join Elimination", Once,
-         EliminateOuterJoin(conf),
-         PushPredicateThroughJoin) :: Nil
-   }
-
-   object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
-     val batches =
-       Batch("Subqueries", Once,
-         EliminateSubqueryAliases) ::
-       Batch("Outer Join Elimination", Once,
-         EliminateOuterJoin(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)),
+         EliminateOuterJoin,
PushPredicateThroughJoin) :: Nil
}

@@ -243,19 +234,25 @@ class OuterJoinEliminationSuite extends PlanTest {
}

test("no outer join elimination if constraint propagation is disabled") {
-     val x = testRelation.subquery('x)
-     val y = testRelation1.subquery('y)
-
-     // The predicate "x.b + y.d >= 3" will be inferred constraints like:
-     // "x.b != null" and "y.d != null", if constraint propagation is enabled.
-     // When we disable it, the predicate can't be evaluated on left or right plan and used to
-     // filter out nulls. So the Outer Join will not be eliminated.
-     val originalQuery =
+     try {
+       SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+
+       val x = testRelation.subquery('x)
+       val y = testRelation1.subquery('y)
+
+       // If constraint propagation is enabled, constraints such as "x.b != null" and
+       // "y.d != null" will be inferred from the predicate "x.b + y.d >= 3".
+       // When we disable it, the predicate can't be evaluated on the left or right plan
+       // and used to filter out nulls, so the outer join will not be eliminated.
+       val originalQuery =
          x.join(y, FullOuter, Option("x.a".attr === "y.d".attr))
            .where("x.b".attr + "y.d".attr >= 3)

-     val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery.analyze)
+       val optimized = Optimize.execute(originalQuery.analyze)

-     comparePlans(optimized, originalQuery.analyze)
+       comparePlans(optimized, originalQuery.analyze)
+     } finally {
+       SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+     }
}
}
@@ -33,7 +33,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
ReplaceExceptWithAntiJoin,
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
-       PruneFilters(conf),
+       PruneFilters,
PropagateEmptyRelation) :: Nil
}

@@ -45,7 +45,7 @@
ReplaceExceptWithAntiJoin,
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
-       PruneFilters(conf)) :: Nil
+       PruneFilters) :: Nil
}

val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1)))
(Diff truncated; the remaining changed files are not shown.)