Skip to content

Commit

Permalink
Respond to review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bersprockets committed Jan 21, 2025
1 parent 93d98e7 commit cb4066a
Showing 1 changed file with 33 additions and 169 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION, OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION,
Expand Down Expand Up @@ -269,9 +269,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {

// Handle the case where the left-hand side of an IN-subquery contains an aggregate.
//
// This handler pulls up any expression containing such an IN-subquery into a new Project
// node, replacing aggregate expressions with attributes. The new Project node will be
// handled by the Unary node handler.
// If an Aggregate node contains such an IN-subquery, this handler will pull up all
// expressions from the Aggregate node into a new Project node. The new Project node
// will then be handled by the Unary node handler.
//
// The Unary node handler uses the left-hand side of the IN-subquery in a
// join condition. Thus, without this pre-transformation, the join condition
Expand All @@ -281,180 +281,44 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
//
// For example:
//
// SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) as x
// FROM v2 GROUP BY col1;
// SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
// FROM v2;
//
// The above query has this plan on entry to RewritePredicateSubquery#apply:
//
// Aggregate [col1#28], [col1#28, sum(col2#29) IN (list#24 []) AS x#25]
// : +- LocalRelation [c2#35L]
// +- LocalRelation [col1#28, col2#29]
// Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13]
// : +- LocalRelation [c3#28L]
// +- LocalRelation [col2#18, col3#19]
//
// Note that the Aggregate node contains the IN-subquery and the left-hand
// side of the IN-subquery is an aggregate expression (sum(col2#29)).
// side of the IN-subquery is an aggregate expression (sum(col2#18)).
//
// This handler transforms the above plan into the following:
// scalastyle:off line.size.limit
//
// Project [col1#28, sum(col2)#36L IN (list#24 []) AS x#25]
// : +- LocalRelation [c2#35L]
// +- Aggregate [col1#28], [col1#28, sum(col2#29) AS sum(col2)#36L]
// +- LocalRelation [col1#28, col2#29]
//
// The transformation pulled the IN-subquery up into a Project. The left-hand side of the
// IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation
// which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L). The Unary
// node handler will use that attribute in the join condition (rather than the aggregate
// expression).
//
// If the IN-subquery is nested in a larger expression, that entire larger
// expression is pulled up into the Project. For example:
//
// SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
// FROM v2;
//
// The input to RewritePredicateSubquery#apply is the following plan:
//
// Aggregate [(sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29]
// : +- LocalRelation [c3#44L]
// +- LocalRelation [col2#34, col3#35]
// Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13]
// : +- LocalRelation [c3#28L]
// +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L]
// +- LocalRelation [col2#18, col3#19]
//
// This handler transforms the plan into:
//
// Project [(sum(col2)#45L IN (list#28 []) AND (sum(col3)#46L > -1)) AS x#29]
// : +- LocalRelation [c3#44L]
// +- Aggregate [sum(col2#34) AS sum(col2)#45L, sum(col3#35) AS sum(col3)#46L]
// +- LocalRelation [col2#34, col3#35]
//
// Note that the entire AND expression was pulled up into the Project, but the Aggregate
// node continues to perform the aggregations (but without the IN-subquery expression).
case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) =>
// Find any interesting expressions from Aggregate.aggregateExpressions.
//
// An interesting expression is one that contains an IN-subquery whose left-hand
// operand contains aggregates. For example:
//
// SELECT col1, SUM(col2) IN (SELECT c2 FROM v1)
// FROM v2 GROUP BY col1;
//
// withInSubquery will be a List containing a single Alias expression:
//
// List(sum(col2#12) IN (list#8 []) AS (...)#19)
val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_))

// Extract the aggregate expressions from each interesting expression. This will include
// any aggregate expressions that were not part of the IN-subquery but were part
// of the larger containing expression.
val inSubqueryMapping = withInSubquery.map { e =>
(e, extractAggregateExpressions(e))
}

// Map each interesting expression to its contained aggregate expressions.
//
// Example #1:
//
// SELECT col1, SUM(col2) IN (SELECT c2 FROM v1)
// FROM v2 GROUP BY col1;
//
// inSubqueryMap will have a single entry mapping an Alias expression to a Vector
// with a single aggregate expression:
//
// Map(
// sum(col2#100) IN (list []) AS (...)#107 -> Vector(sum(col2#100))
// )
//
// Example #2:
//
// SELECT (SUM(col1), SUM(col2)) IN (SELECT c1, c2 FROM v1)
// FROM v2;
//
// inSubqueryMap will have a single entry mapping an Alias expression to a Vector
// with two aggregate expressions:
//
// Map(
// named_struct(_0, sum(col1#169), _1, sum(col2#170)) IN (list#166 []) AS (...)#179
// -> Vector(sum(col1#169), sum(col2#170))
// )
//
// Example #3:
//
// SELECT SUM(col1) IN (SELECT c1 FROM v1), SUM(col2) IN (SELECT c2 FROM v1)
// FROM v2;
//
// inSubqueryMap will have two entries, each mapping an Alias expression to a Vector
// with a single aggregate expression:
//
// Map(
// sum(col1#193) IN (list#189 []) AS (...)#207 -> Vector(sum(col1#193)),
// sum(col2#194) IN (list#190 []) AS (...)#208 -> Vector(sum(col2#194))
// )
//
// Example #4:
//
// SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
// FROM v2;
//
// inSubqueryMap will contain a single AND expression that maps to two aggregate
// expressions, even though only one of those aggregate expressions is used as
// the left-hand operand of the IN-subquery expression.
//
// Map(
// (sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29
// -> Vector(sum(col2#34), sum(col3#35))
// )
//
// The keys of inSubqueryMap will be used to determine which expressions in
// the old Aggregate node are interesting. The values of inSubqueryMap, after
// being wrapped in Alias expressions, will replace their associated interesting
// expressions in a new Aggregate node.
val inSubqueryMap = inSubqueryMapping.toMap

// Get all aggregate expressions associated with interesting expressions.
val aggregateExprs = inSubqueryMapping.flatMap(_._2)
// Create aliases for each above aggregate expression. We can't use the aggregate
// expressions directly in the new Aggregate node because Aggregate.aggregateExpressions
// has the type Seq[NamedExpression].
val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))())
// Create a mapping from each aggregate expression to its alias.
val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap
// Create attributes from those aliases of aggregate expressions. These attributes
// will be used in the new Project node to refer to the aliased aggregate expressions
// in the new Aggregate node.
val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute)
// Create a mapping from aggregate expressions to attributes. This will be
// used when patching the interesting expressions after they are pulled up
// into the new Project node: aggregate expressions will be replaced by attributes.
val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap

// Create an Aggregate node without the interesting expressions, just
// the associated aggregate expressions plus any other group-by or aggregate expressions
// that were not involved in the interesting expressions.
val newAggregateExpressions = a.aggregateExpressions.flatMap {
// If this expression contains IN-subqueries with aggregates in the left-hand
// operand, replace with just the aggregates.
case ae: Expression if inSubqueryMap.contains(ae) =>
// Replace the expression with an aliased aggregate expression.
inSubqueryMap(ae).map(aggregateExprAliasMap(_))
case ae => Seq(ae)
}
val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions)

// Create a projection with the IN-subquery expressions that contain aggregates, replacing
// the aggregate expressions with attribute references to the output of the new Aggregate
// operator. Also include the other output of the Aggregate operator.
val projList = a.aggregateExpressions.map {
// If this expression contains an IN-subquery that uses an aggregate, we
// need to do something special
case ae: Expression if inSubqueryMap.contains(ae) =>
ae.transform {
// Patch any aggregate expression with its corresponding attribute.
case a: AggregateExpression => aggregateExprAttrMap(a)
}.asInstanceOf[NamedExpression]
case ae => ae.toAttribute
}
val newProj = Project(projList, newAggregate)

// Call the unary node handler, but now with all interesting expressions
// from Aggregate.aggregateExpressions pulled up into a Project node.
// scalastyle:on
// Note that both the IN-subquery and the greater-than expressions have been
// pulled up into the Project node. These expressions use attributes
// (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations
// which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)).
case p @ PhysicalAggregation(
groupingExpressions, aggregateExpressions, resultExpressions, child)
if exprsContainsAggregateInSubquery(p.expressions) =>
val aggExprs = aggregateExpressions.map(
ae => Alias(ae, "_aggregateexpression")(ae.resultId))
val aggExprIds = aggExprs.map(_.exprId).toSet
val resExprs = resultExpressions.map(_.transform {
case a: AttributeReference if aggExprIds.contains(a.exprId) =>
a.withName("_aggregateexpression")
}.asInstanceOf[NamedExpression])
// Rewrite the projection and the aggregate separately and then piece them together.
val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
val newProj = Project(resExprs, newAgg)
handleUnaryNode(newProj)

case u: UnaryNode if u.expressions.exists(
Expand Down

0 comments on commit cb4066a

Please sign in to comment.