From cb4066a057655b3a8546ffb5c5f0b98de4685c8a Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 16 Jan 2025 16:03:09 -0800 Subject: [PATCH] Respond to review comments --- .../sql/catalyst/optimizer/subquery.scala | 202 +++--------------- 1 file changed, 33 insertions(+), 169 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 6ee11009ba943..378081221c8c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -27,11 +27,11 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery +import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY} -import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION, OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION, @@ -269,9 +269,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Handle the case where the left-hand side of an IN-subquery contains an aggregate. // - // This handler pulls up any expression containing such an IN-subquery into a new Project - // node, replacing aggregate expressions with attributes. The new Project node will be - // handled by the Unary node handler. + // If an Aggregate node contains such an IN-subquery, this handler will pull up all + // expressions from the Aggregate node into a new Project node. The new Project node + // will then be handled by the Unary node handler. // // The Unary node handler uses the left-hand side of the IN-subquery in a // join condition. Thus, without this pre-transformation, the join condition @@ -281,180 +281,44 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // // For example: // - // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) as x - // FROM v2 GROUP BY col1; + // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x + // FROM v2; // // The above query has this plan on entry to RewritePredicateSubquery#apply: // - // Aggregate [col1#28], [col1#28, sum(col2#29) IN (list#24 []) AS x#25] - // : +- LocalRelation [c2#35L] - // +- LocalRelation [col1#28, col2#29] + // Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13] + // : +- LocalRelation [c3#28L] + // +- LocalRelation [col2#18, col3#19] // // Note that the Aggregate node contains the IN-subquery and the left-hand - // side of the IN-subquery is an aggregate expression (sum(col2#29)). + // side of the IN-subquery is an aggregate expression sum(col2#18)). // // This handler transforms the above plan into the following: + // scalastyle:off line.size.limit // - // Project [col1#28, sum(col2)#36L IN (list#24 []) AS x#25] - // : +- LocalRelation [c2#35L] - // +- Aggregate [col1#28], [col1#28, sum(col2#29) AS sum(col2)#36L] - // +- LocalRelation [col1#28, col2#29] - // - // The transformation pulled the IN-subquery up into a Project. The left-hand side of the - // IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation - // which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L). The Unary - // node handler will use that attribute in the join condition (rather than the aggregate - // expression). - // - // If the IN-subquery is nested in a larger expression, that entire larger - // expression is pulled up into the Project. For example: - // - // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x - // FROM v2; - // - // The input to RewritePredicateSubquery#apply is the following plan: - // - // Aggregate [(sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29] - // : +- LocalRelation [c3#44L] - // +- LocalRelation [col2#34, col3#35] + // Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13] + // : +- LocalRelation [c3#28L] + // +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L] + // +- LocalRelation [col2#18, col3#19] // - // This handler transforms the plan into: - // - // Project [(sum(col2)#45L IN (list#28 []) AND (sum(col3)#46L > -1)) AS x#29] - // : +- LocalRelation [c3#44L] - // +- Aggregate [sum(col2#34) AS sum(col2)#45L, sum(col3#35) AS sum(col3)#46L] - // +- LocalRelation [col2#34, col3#35] - // - // Note that the entire AND expression was pulled up into the Project, but the Aggregate - // node continues to perform the aggregations (but without the IN-subquery expression). - case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) => - // Find any interesting expressions from Aggregate.aggregateExpressions. - // - // An interesting expression is one that contains an IN-subquery whose left-hand - // operand contains aggregates. For example: - // - // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) - // FROM v2 GROUP BY col1; - // - // withInSubquery will be a List containing a single Alias expression: - // - // List(sum(col2#12) IN (list#8 []) AS (...)#19) - val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_)) - - // Extract the aggregate expressions from each interesting expression. This will include - // any aggregate expressions that were not part of the IN-subquery but were part - // of the larger containing expression. - val inSubqueryMapping = withInSubquery.map { e => - (e, extractAggregateExpressions(e)) - } - - // Map each interesting expression to its contained aggregate expressions. - // - // Example #1: - // - // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) - // FROM v2 GROUP BY col1; - // - // inSubqueryMap will have a single entry mapping an Alias expression to a Vector - // with a single aggregate expression: - // - // Map( - // sum(col2#100) IN (list []) AS (...)#107 -> Vector(sum(col2#100)) - // ) - // - // Example #2: - // - // SELECT (SUM(col1), SUM(col2)) IN (SELECT c1, c2 FROM v1) - // FROM v2; - // - // inSubqueryMap will have a single entry mapping an Alias expression to a Vector - // with two aggregate expressions: - // - // Map( - // named_struct(_0, sum(col1#169), _1, sum(col2#170)) IN (list#166 []) AS (...)#179 - // -> Vector(sum(col1#169), sum(col2#170)) - // ) - // - // Example #3: - // - // select SUM(col1) IN (SELECT c1 FROM v1), SUM(col2) IN (SELECT c2 FROM v1) - // FROM v2; - // - // inSubqueryMap will have two entries, each mapping an Alias expression to a Vector - // with a single aggregate expression: - // - // Map( - // sum(col1#193) IN (list#189 []) AS (...)#207 -> Vector(sum(col1#193)), - // sum(col2#194) IN (list#190 []) AS (...)#208 -> Vector(sum(col2#194)) - // ) - // - // Example #5: - // - // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x - // FROM v2; - // - // inSubqueryMap will contain a single AND expression that maps to two aggregate - // expressions, even though only one of those aggregate expressions is used as - // the left-hand operand of the IN-subquery expression. - // - // Map( - // (sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29 - // -> Vector(sum(col2#34), sum(col3#35)) - // ) - // - // The keys of inSubqueryMap will be used to determine which expressions in - // the old Aggregate node are interesting. The values of inSubqueryMap, after - // being wrapped in Alias expressions, will replace their associated interesting - // expressions in a new Aggregate node. - val inSubqueryMap = inSubqueryMapping.toMap - - // Get all aggregate expressions associated with interesting expressions. - val aggregateExprs = inSubqueryMapping.flatMap(_._2) - // Create aliases for each above aggregate expression. We can't use the aggregate - // expressions directly in the new Aggregate node because Aggregate.aggregateExpressions - // has the type Seq[NamedExpression]. - val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))()) - // Create a mapping from each aggregate expression to its alias. - val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap - // Create attributes from those aliases of aggregate expressions. These attributes - // will be used in the new Project node to refer to the aliased aggregate expressions - // in the new Aggregate node. - val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute) - // Create a mapping from aggregate expressions to attributes. This will be - // used when patching the interesting expressions after they are pulled up - // into the new Project node: aggregate expressions will be replaced by attributes. - val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap - - // Create an Aggregate node without the interesting expressions, just - // the associated aggregate expressions plus any other group-by or aggregate expressions - // that were not involved in the interesting expressions. - val newAggregateExpressions = a.aggregateExpressions.flatMap { - // If this expression contains IN-subqueries with aggregates in the left-hand - // operand, replace with just the aggregates. - case ae: Expression if inSubqueryMap.contains(ae) => - // Replace the expression with an aliased aggregate expression. - inSubqueryMap(ae).map(aggregateExprAliasMap(_)) - case ae => Seq(ae) - } - val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions) - - // Create a projection with the IN-subquery expressions that contain aggregates, replacing - // the aggregate expressions with attribute references to the output of the new Aggregate - // operator. Also include the other output of the Aggregate operator. - val projList = a.aggregateExpressions.map { - // If this expression contains an IN-subquery that uses an aggregate, we - // need to do something special - case ae: Expression if inSubqueryMap.contains(ae) => - ae.transform { - // Patch any aggregate expression with its corresponding attribute. - case a: AggregateExpression => aggregateExprAttrMap(a) - }.asInstanceOf[NamedExpression] - case ae => ae.toAttribute - } - val newProj = Project(projList, newAggregate) - - // Call the unary node handler, but now with all interesting expressions - // from Aggregate.aggregateExpressions pulled up into a Project node. + // scalastyle:on + // Note that both the IN-subquery and the greater-than expressions have been + // pulled up into the Project node. These expressions use attributes + // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations + // which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)). + case p @ PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child) + if exprsContainsAggregateInSubquery(p.expressions) => + val aggExprs = aggregateExpressions.map( + ae => Alias(ae, "_aggregateexpression")(ae.resultId)) + val aggExprIds = aggExprs.map(_.exprId).toSet + val resExprs = resultExpressions.map(_.transform { + case a: AttributeReference if aggExprIds.contains(a.exprId) => + a.withName("_aggregateexpression") + }.asInstanceOf[NamedExpression]) + // Rewrite the projection and the aggregate separately and then piece them together. + val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child) + val newProj = Project(resExprs, newAgg) handleUnaryNode(newProj) case u: UnaryNode if u.expressions.exists(