From 6573627d9c069bcbf5d02de12ec2e6d36150d018 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 24 Nov 2024 13:03:42 -0800 Subject: [PATCH] Review updates --- .../sql/catalyst/optimizer/subquery.scala | 174 ++++++++++++++++-- 1 file changed, 155 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 68e8c7afe2257..898ca03989d45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -266,55 +266,191 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { condition = Some(newCondition))) } } + + // Handle the case where the left-hand side of an IN-subquery contains an aggregate. + // + // This handler pulls up any expression containing such an IN-subquery into a new Project + // node and then re-enters RewritePredicateSubquery#apply, where the new Project node + // will be handled by the Unary node handler. The Unary node handler will transform the + // plan into a join. Without this pre-transformation, the Unary node handler would + // create a join with an aggregate expression in the join condition, which is illegal + // (see SPARK-50091). + // + // For example: + // + // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) as x + // FROM v2 GROUP BY col1; + // + // The above query has this plan on entry to RewritePredicateSubquery#apply: + // + // Aggregate [col1#28], [col1#28, sum(col2#29) IN (list#24 []) AS x#25] + // : +- LocalRelation [c2#35L] + // +- LocalRelation [col1#28, col2#29] + // + // Note that the Aggregate node contains the IN-subquery and the left-hand + // side of the IN-subquery is an aggregate expression (sum(col2#29)). 
+ // + // This handler transforms the above plan into the following: + // + // Project [col1#28, sum(col2)#36L IN (list#24 []) AS x#25] + // : +- LocalRelation [c2#35L] + // +- Aggregate [col1#28], [col1#28, sum(col2#29) AS sum(col2)#36L] + // +- LocalRelation [col1#28, col2#29] + // + // The transformation pulled the IN-subquery up into a Project. The left-hand side of the + // IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation + // which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L). + // + // Note that if the IN-subquery is nested in a larger expression, that entire larger + // expression is pulled up into the Project. For example: + // + // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x + // FROM v2; + // + // The input to RewritePredicateSubquery#apply is the following plan: + // + // Aggregate [(sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29] + // : +- LocalRelation [c3#44L] + // +- LocalRelation [col2#34, col3#35] + // + // This handler transforms the plan into: + // + // Project [(sum(col2)#45L IN (list#28 []) AND (sum(col3)#46L > -1)) AS x#29] + // : +- LocalRelation [c3#44L] + // +- Aggregate [sum(col2#34) AS sum(col2)#45L, sum(col3#35) AS sum(col3)#46L] + // +- LocalRelation [col2#34, col3#35] + // + // Note that the entire AND expression was pulled up into the Project, but the Aggregate + // node continues to perform the aggregations (but without the IN-subquery expression). case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) => - // find expressions with an IN-subquery whose left-hand operand contains aggregates + // Find any interesting expressions from Aggregate.aggregateExpressions. + // + // An interesting expression is one that contains an IN-subquery whose left-hand + // operand contains aggregates. 
For example: + // + // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) + // FROM v2 GROUP BY col1; + // + // withInSubquery will be a List containing a single Alias expression: + // + // List(sum(col2#12) IN (list#8 []) AS (...)#19) val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_)) - // extract the aggregate expressions from withInSubquery + // Extract the aggregate expressions from each interesting expression. This will include + // any aggregate expressions that were not part of the IN-subquery but were part + // of the larger containing expression. val inSubqueryMapping = withInSubquery.map { e => (e, extractAggregateExpressions(e)) } + // Map each interesting expression to its contained aggregate expressions. + // + // Example #1: + // + // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) + // FROM v2 GROUP BY col1; + // + // inSubqueryMap will have a single entry mapping an Alias expression to a Vector + // with a single aggregate expression: + // + // Map( + // sum(col2#100) IN (list []) AS (...)#107 -> Vector(sum(col2#100)) + // ) + // + // Example #2: + // + // SELECT (SUM(col1), SUM(col2)) IN (SELECT c1, c2 FROM v1) + // FROM v2; + // + // inSubqueryMap will have a single entry mapping an Alias expression to a Vector + // with two aggregate expressions: + // + // Map( + // named_struct(_0, sum(col1#169), _1, sum(col2#170)) IN (list#166 []) AS (...)#179 + // -> Vector(sum(col1#169), sum(col2#170)) + // ) + // + // Example #3: + // + // SELECT SUM(col1) IN (SELECT c1 FROM v1), SUM(col2) IN (SELECT c2 FROM v1) + // FROM v2; + // + // inSubqueryMap will have two entries, each mapping an Alias expression to a Vector + // with a single aggregate expression: + // + // Map( + // sum(col1#193) IN (list#189 []) AS (...)#207 -> Vector(sum(col1#193)), + // sum(col2#194) IN (list#190 []) AS (...)#208 -> Vector(sum(col2#194)) + // ) + // + // Example #4: + // + // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x + // FROM 
v2; + // + // inSubqueryMap will contain a single AND expression that maps to two aggregate + // expressions, even though only one of those aggregate expressions is used as + // the left-hand operand of the IN-subquery expression. + // + // Map( + // (sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29 + // -> Vector(sum(col2#34), sum(col3#35)) + // ) + // + // The keys of inSubqueryMap will be used to determine which expressions in + // the old Aggregate node are interesting. The values of inSubqueryMap, after + // being wrapped in Alias expressions, will replace their associated interesting + // expressions in a new Aggregate node. val inSubqueryMap = inSubqueryMapping.toMap - // get all aggregate expressions found in left-hand operands of IN-subqueries + + // Get all aggregate expressions associated with interesting expressions. val aggregateExprs = inSubqueryMapping.flatMap(_._2) - // create aliases for each above aggregate expression + // Create aliases for each above aggregate expression. We can't use the aggregate + // expressions directly in the new Aggregate node because Aggregate.aggregateExpressions + // has the type Seq[NamedExpression]. val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))()) - // create a mapping from each aggregate expression to its alias + // Create a mapping from each aggregate expression to its alias. val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap - // create attributes from those aliases of aggregate expressions + // Create attributes from those aliases of aggregate expressions. These attributes + // will be used in the new Project node to refer to the aliased aggregate expressions + // in the new Aggregate node. val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute) - // create a mapping from aggregate expressions to attributes + // Create a mapping from aggregate expressions to attributes. 
This will be + // used when patching the interesting expressions after they are pulled up + // into the new Project node: aggregate expressions will be replaced by attributes. val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap - // create an Aggregate node without the offending IN-subqueries, just - // the aggregates themselves and all the other aggregate expressions. + // Create an Aggregate node without the interesting expressions, just + // the associated aggregate expressions plus any other group-by or aggregate expressions + // that were not involved in the interesting expressions. val newAggregateExpressions = a.aggregateExpressions.flatMap { - // if this expression contains IN-subqueries with aggregates in the left-hand - // operand, replace with just the aggregates + // If this expression contains IN-subqueries with aggregates in the left-hand + // operand, replace with just the aggregates. case ae: Expression if inSubqueryMap.contains(ae) => - // replace the expression with an aliased aggregate expression + // Replace the expression with an aliased aggregate expression. inSubqueryMap(ae).map(aggregateExprAliasMap(_)) - case ae @ _ => Seq(ae) + case ae => Seq(ae) } val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions) // Create a projection with the IN-subquery expressions that contain aggregates, replacing - // the aggregate expressions with attribute references to the output of the Aggregate + // the aggregate expressions with attribute references to the output of the new Aggregate // operator. Also include the other output of the Aggregate operator. 
val projList = a.aggregateExpressions.map { - // if this expression contains an IN-subquery that uses an aggregate, we + // If this expression contains an IN-subquery that uses an aggregate, we // need to do something special case ae: Expression if inSubqueryMap.contains(ae) => ae.transform { - // patch any aggregate expression with its corresponding attribute + // Patch any aggregate expression with its corresponding attribute. case a: AggregateExpression => aggregateExprAttrMap(a) }.asInstanceOf[NamedExpression] - case ae @ _ => ae.toAttribute + case ae => ae.toAttribute } + val newProj = Project(projList, newAggregate) - // reapply this rule, now with a Project as parent to the Aggregate - apply(Project(projList, newAggregate)) + // Reapply this rule, but now with all interesting expressions + // from Aggregate.aggregateExpressions pulled up into a Project node. + apply(newProj) case u: UnaryNode if u.expressions.exists( SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>