Skip to content

Commit

Permalink
Move unary node handler to its own utility method
Browse files Browse the repository at this point in the history
  • Loading branch information
bersprockets committed Jan 4, 2025
1 parent a9434ea commit a866ebe
Showing 1 changed file with 49 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Handle the case where the left-hand side of an IN-subquery contains an aggregate.
//
// This handler pulls up any expression containing such an IN-subquery into a new Project
// node, replacing aggregate expressions with attributes, and then re-enters
// RewritePredicateSubquery#apply, where the new Project node will be handled
// by the Unary node handler.
// node, replacing aggregate expressions with attributes. The new Project node will be
// handled by the Unary node handler.
//
// The Unary node handler uses the left-hand side of the IN-subquery in a
// join condition. Thus, without this pre-transformation, the join condition
Expand Down Expand Up @@ -454,50 +453,56 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
val newProj = Project(projList, newAggregate)

// Reapply this rule, but now with all interesting expressions
// Call the unary node handler, but now with all interesting expressions
// from Aggregate.aggregateExpressions pulled up into a Project node.
apply(newProj)
handleUnaryNode(newProj)

case u: UnaryNode if u.expressions.exists(
SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
var newChild = u.child
var introducedAttrs = Seq.empty[Attribute]
val updatedNode = u.mapExpressions(expr => {
val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
newChild = p
introducedAttrs ++= newAttrs
// The newExpr can not be None
newExpr.get
}).withNewChildren(Seq(newChild))
updatedNode match {
case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
// If we have introduced new `exists`-attributes that are referenced by
// aggregateExpressions within a non-aggregateFunction expression, we wrap them in
// first() aggregate function. first() is Spark's executable version of any_value()
// aggregate function.
// We do this to keep the aggregation valid, i.e avoid references outside of aggregate
// functions that are not in grouping expressions.
// Note that the same `exists` attr will never appear in groupingExpressions due to
// PullOutGroupingExpressions rule.
// Also note: the value of `exists` is functionally determined by grouping expressions,
// so applying any aggregate function is semantically safe.
val aggFunctionReferences = a.aggregateExpressions.
flatMap(extractAggregateExpressions).
flatMap(_.references).toSet
val nonAggFuncReferences =
a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)

// Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
aggExpr.transformUp {
case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
new First(attr).toAggregateExpression()
}.asInstanceOf[NamedExpression]
}
a.copy(aggregateExpressions = newAggregateExpressions)
case _ => updatedNode
}
SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u)
}

/**
* Handle the unary node case
*/
private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
var newChild = u.child
var introducedAttrs = Seq.empty[Attribute]
val updatedNode = u.mapExpressions(expr => {
val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
newChild = p
introducedAttrs ++= newAttrs
// The newExpr can not be None
newExpr.get
}).withNewChildren(Seq(newChild))
updatedNode match {
case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
// If we have introduced new `exists`-attributes that are referenced by
// aggregateExpressions within a non-aggregateFunction expression, we wrap them in
// first() aggregate function. first() is Spark's executable version of any_value()
// aggregate function.
// We do this to keep the aggregation valid, i.e avoid references outside of aggregate
// functions that are not in grouping expressions.
// Note that the same `exists` attr will never appear in groupingExpressions due to
// PullOutGroupingExpressions rule.
// Also note: the value of `exists` is functionally determined by grouping expressions,
// so applying any aggregate function is semantically safe.
val aggFunctionReferences = a.aggregateExpressions.
flatMap(extractAggregateExpressions).
flatMap(_.references).toSet
val nonAggFuncReferences =
a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)

// Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
aggExpr.transformUp {
case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
new First(attr).toAggregateExpression()
}.asInstanceOf[NamedExpression]
}
a.copy(aggregateExpressions = newAggregateExpressions)
case _ => updatedNode
}
}

/**
Expand Down

0 comments on commit a866ebe

Please sign in to comment.