Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-32755][SQL] Maintain the order of expressions in AttributeSet and ExpressionSet #29598

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ object ExpressionSet {
expressions.foreach(set.add)
set
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should apply the same change in ExpressionSet under the scala-2.13 source tree. @dbaliafroozeh can you open a followup PR?

Copy link
Contributor Author

@dbaliafroozeh dbaliafroozeh Sep 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan good catch, I thought I already deleted the ExpressionSet in 2.13. Note that we don't want it anymore as ExpressionSet doesn't extend Set anymore. I'll open a followup PR for that.

def apply(): ExpressionSet = {
new ExpressionSet()
}
}

/**
Expand All @@ -53,46 +57,102 @@ object ExpressionSet {
* This is consistent with how we define `semanticEquals` between two expressions.
*/
class ExpressionSet protected(
protected val baseSet: mutable.Set[Expression] = new mutable.HashSet,
protected val originals: mutable.Buffer[Expression] = new ArrayBuffer)
extends Set[Expression] {
private val baseSet: mutable.Set[Expression] = new mutable.HashSet,
private val originals: mutable.Buffer[Expression] = new ArrayBuffer)
extends Iterable[Expression] {

// Note: this class supports Scala 2.12. A parallel source tree has a 2.13 implementation.

protected def add(e: Expression): Unit = {
if (!e.deterministic) {
originals += e
} else if (!baseSet.contains(e.canonicalized) ) {
} else if (!baseSet.contains(e.canonicalized)) {
baseSet.add(e.canonicalized)
originals += e
}
}

override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)
protected def remove(e: Expression): Unit = {
if (e.deterministic) {
baseSet --= baseSet.filter(_ == e.canonicalized)
originals --= originals.filter(_.canonicalized == e.canonicalized)
}
}

def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)

override def filter(p: Expression => Boolean): ExpressionSet = {
val newBaseSet = baseSet.filter(e => p(e.canonicalized))
val newOriginals = originals.filter(e => p(e.canonicalized))
new ExpressionSet(newBaseSet, newOriginals)
}

override def filterNot(p: Expression => Boolean): ExpressionSet = {
val newBaseSet = baseSet.filterNot(e => p(e.canonicalized))
val newOriginals = originals.filterNot(e => p(e.canonicalized))
new ExpressionSet(newBaseSet, newOriginals)
}

override def +(elem: Expression): ExpressionSet = {
val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
def +(elem: Expression): ExpressionSet = {
val newSet = clone()
newSet.add(elem)
newSet
}

override def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = {
val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = {
val newSet = clone()
elems.foreach(newSet.add)
newSet
}

override def -(elem: Expression): ExpressionSet = {
if (elem.deterministic) {
val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized)
val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized)
new ExpressionSet(newBaseSet, newOriginals)
} else {
new ExpressionSet(baseSet.clone(), originals.clone())
}
def -(elem: Expression): ExpressionSet = {
val newSet = clone()
newSet.remove(elem)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this more efficient?:

ExpressionSet(baseSet.filter(_ != e. canonicalized), originals.filter(_.canonicalized != e.canonicalized))

newSet
}

def --(elems: GenTraversableOnce[Expression]): ExpressionSet = {
val newSet = clone()
elems.foreach(newSet.remove)
newSet
}

override def iterator: Iterator[Expression] = originals.iterator
def map(f: Expression => Expression): ExpressionSet = {
val newSet = new ExpressionSet()
this.iterator.foreach(elem => newSet.add(f(elem)))
newSet
}

def flatMap(f: Expression => Iterable[Expression]): ExpressionSet = {
val newSet = new ExpressionSet()
this.iterator.foreach(f(_).foreach(newSet.add))
newSet
}

def iterator: Iterator[Expression] = originals.iterator

def union(that: ExpressionSet): ExpressionSet = {
val newSet = clone()
that.iterator.foreach(newSet.add)
newSet
}

def subsetOf(that: ExpressionSet): Boolean = this.iterator.forall(that.contains)

def intersect(that: ExpressionSet): ExpressionSet = this.filter(that.contains)

def diff(that: ExpressionSet): ExpressionSet = this -- that

def apply(elem: Expression): Boolean = this.contains(elem)

override def equals(obj: Any): Boolean = obj match {
case other: ExpressionSet => this.baseSet == other.baseSet
case _ => false
}

override def hashCode(): Int = baseSet.hashCode()

override def clone(): ExpressionSet = new ExpressionSet(baseSet.clone(), originals.clone())

/**
* Returns a string containing both the post [[Canonicalize]] expressions and the original
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ object AttributeSet {
val empty = apply(Iterable.empty)

/** Constructs a new [[AttributeSet]] that contains a single [[Attribute]]. */
def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
def apply(a: Attribute): AttributeSet = {
val baseSet = new mutable.LinkedHashSet[AttributeEquals]
baseSet += new AttributeEquals(a)
new AttributeSet(baseSet)
}

/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
def apply(baseSet: Iterable[Expression]): AttributeSet = {
Expand All @@ -47,7 +51,7 @@ object AttributeSet {
/** Constructs a new [[AttributeSet]] given a sequence of [[AttributeSet]]s. */
def fromAttributeSets(sets: Iterable[AttributeSet]): AttributeSet = {
val baseSet = sets.foldLeft(new mutable.LinkedHashSet[AttributeEquals]())( _ ++= _.baseSet)
new AttributeSet(baseSet.toSet)
new AttributeSet(baseSet)
}
}

Expand All @@ -62,7 +66,7 @@ object AttributeSet {
* and also makes doing transformations hard (we always try keep older trees instead of new ones
* when the transformation was a no-op).
*/
class AttributeSet private (val baseSet: Set[AttributeEquals])
class AttributeSet private (private val baseSet: mutable.LinkedHashSet[AttributeEquals])
extends Iterable[Attribute] with Serializable {

override def hashCode: Int = baseSet.hashCode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, ExpressionSet, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -75,18 +75,18 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
* Extracts items of consecutive inner joins and join conditions.
* This method works for bushy trees and left/right deep trees.
*/
private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], ExpressionSet) = {
plan match {
case Join(left, right, _: InnerLike, Some(cond), JoinHint.NONE) =>
val (leftPlans, leftConditions) = extractInnerJoins(left)
val (rightPlans, rightConditions) = extractInnerJoins(right)
(leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
leftConditions ++ rightConditions)
(leftPlans ++ rightPlans, leftConditions ++ rightConditions ++
splitConjunctivePredicates(cond))
case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE))
if projectList.forall(_.isInstanceOf[Attribute]) =>
extractInnerJoins(j)
case _ =>
(Seq(plan), Set())
(Seq(plan), ExpressionSet())
}
}

Expand Down Expand Up @@ -143,15 +143,15 @@ object JoinReorderDP extends PredicateHelper with Logging {
def search(
conf: SQLConf,
items: Seq[LogicalPlan],
conditions: Set[Expression],
conditions: ExpressionSet,
output: Seq[Attribute]): LogicalPlan = {

val startTime = System.nanoTime()
// Level i maintains all found plans for i + 1 items.
// Create the initial plans: each plan is a single item with zero cost.
val itemIndex = items.zipWithIndex
val foundPlans = mutable.Buffer[JoinPlanMap](itemIndex.map {
case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set.empty, Cost(0, 0))
case (item, id) => Set(id) -> JoinPlan(Set(id), item, ExpressionSet(), Cost(0, 0))
}.toMap)

// Build filters from the join graph to be used by the search algorithm.
Expand Down Expand Up @@ -194,7 +194,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
private def searchLevel(
existingLevels: Seq[JoinPlanMap],
conf: SQLConf,
conditions: Set[Expression],
conditions: ExpressionSet,
topOutput: AttributeSet,
filters: Option[JoinGraphInfo]): JoinPlanMap = {

Expand Down Expand Up @@ -255,7 +255,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
oneJoinPlan: JoinPlan,
otherJoinPlan: JoinPlan,
conf: SQLConf,
conditions: Set[Expression],
conditions: ExpressionSet,
topOutput: AttributeSet,
filters: Option[JoinGraphInfo]): Option[JoinPlan] = {

Expand Down Expand Up @@ -329,7 +329,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
case class JoinPlan(
itemIds: Set[Int],
plan: LogicalPlan,
joinConds: Set[Expression],
joinConds: ExpressionSet,
planCost: Cost) {

/** Get the cost of the root node of this plan tree. */
Expand Down Expand Up @@ -387,7 +387,7 @@ object JoinReorderDPFilters extends PredicateHelper {
def buildJoinGraphInfo(
conf: SQLConf,
items: Seq[LogicalPlan],
conditions: Set[Expression],
conditions: ExpressionSet,
itemIndex: Seq[(LogicalPlan, Int)]): Option[JoinGraphInfo] = {

if (conf.joinReorderDPStarFilter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -921,13 +921,13 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
private def getAllConstraints(
left: LogicalPlan,
right: LogicalPlan,
conditionOpt: Option[Expression]): Set[Expression] = {
conditionOpt: Option[Expression]): ExpressionSet = {
val baseConstraints = left.constraints.union(right.constraints)
.union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet)
.union(ExpressionSet(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil)))
baseConstraints.union(inferAdditionalConstraints(baseConstraints))
}

private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = {
private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = {
val newPredicates = constraints
.union(constructIsNotNullConstraints(constraints, plan.output))
.filter { c =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ abstract class UnaryNode extends LogicalPlan {
* Generates all valid constraints including an set of aliased constraints by replacing the
* original constraint expressions with the corresponding alias
*/
protected def getAllValidConstraints(projectList: Seq[NamedExpression]): Set[Expression] = {
var allConstraints = child.constraints.asInstanceOf[Set[Expression]]
protected def getAllValidConstraints(projectList: Seq[NamedExpression]): ExpressionSet = {
var allConstraints = child.constraints
projectList.foreach {
case a @ Alias(l: Literal, _) =>
allConstraints += EqualNullSafe(a.toAttribute, l)
Expand All @@ -187,7 +187,7 @@ abstract class UnaryNode extends LogicalPlan {
allConstraints
}

override protected lazy val validConstraints: Set[Expression] = child.constraints
override protected lazy val validConstraints: ExpressionSet = child.constraints
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,14 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>
*/
lazy val constraints: ExpressionSet = {
if (conf.constraintPropagationEnabled) {
ExpressionSet(
validConstraints
.union(inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints, output))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
}
)
validConstraints
.union(inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints, output))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
}
} else {
ExpressionSet(Set.empty)
ExpressionSet()
}
}

Expand All @@ -50,7 +48,7 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>
*
* See [[Canonicalize]] for more details.
*/
protected lazy val validConstraints: Set[Expression] = Set.empty
protected lazy val validConstraints: ExpressionSet = ExpressionSet()
}

trait ConstraintHelper {
Expand All @@ -60,8 +58,8 @@ trait ConstraintHelper {
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`.
*/
def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
var inferredConstraints = Set.empty[Expression]
def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
var inferredConstraints = ExpressionSet()
// IsNotNull should be constructed by `constructIsNotNullConstraints`.
val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
predicates.foreach {
Expand All @@ -79,9 +77,9 @@ trait ConstraintHelper {
}

private def replaceConstraints(
constraints: Set[Expression],
constraints: ExpressionSet,
source: Expression,
destination: Expression): Set[Expression] = constraints.map(_ transform {
destination: Expression): ExpressionSet = constraints.map(_ transform {
case e: Expression if e.semanticEquals(source) => destination
})

Expand All @@ -91,15 +89,15 @@ trait ConstraintHelper {
* returns a constraint of the form `isNotNull(a)`
*/
def constructIsNotNullConstraints(
constraints: Set[Expression],
output: Seq[Attribute]): Set[Expression] = {
constraints: ExpressionSet,
output: Seq[Attribute]): ExpressionSet = {
// First, we propagate constraints from the null intolerant expressions.
var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)
var isNotNullConstraints = constraints.flatMap(inferIsNotNullConstraints(_))

// Second, we infer additional constraints from non-nullable attributes that are part of the
// operator's output
val nonNullableAttributes = output.filterNot(_.nullable)
isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet
isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull)

isNotNullConstraints -- constraints
}
Expand Down
Loading