bersprockets committed Nov 24, 2024
1 parent 3319192 commit 6573627
Showing 1 changed file with 155 additions and 19 deletions.
@@ -266,55 +266,191 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
condition = Some(newCondition)))
}
}

// Handle the case where the left-hand side of an IN-subquery contains an aggregate.
//
// This handler pulls up any expression containing such an IN-subquery into a new Project
// node and then re-enters RewritePredicateSubquery#apply, where the new Project node
// will be handled by the Unary node handler. The Unary node handler will transform the
// plan into a join. Without this pre-transformation, the Unary node handler would
// create a join with an aggregate expression in the join condition, which is illegal
// (see SPARK-50091).
//
// For example:
//
// SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) as x
// FROM v2 GROUP BY col1;
//
// The above query has this plan on entry to RewritePredicateSubquery#apply:
//
// Aggregate [col1#28], [col1#28, sum(col2#29) IN (list#24 []) AS x#25]
// : +- LocalRelation [c2#35L]
// +- LocalRelation [col1#28, col2#29]
//
// Note that the Aggregate node contains the IN-subquery and the left-hand
// side of the IN-subquery is an aggregate expression (sum(col2#29)).
//
// This handler transforms the above plan into the following:
//
// Project [col1#28, sum(col2)#36L IN (list#24 []) AS x#25]
// : +- LocalRelation [c2#35L]
// +- Aggregate [col1#28], [col1#28, sum(col2#29) AS sum(col2)#36L]
// +- LocalRelation [col1#28, col2#29]
//
// The transformation pulled the IN-subquery up into a Project. The left-hand side of the
// IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation,
// which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L).
//
// Note that if the IN-subquery is nested in a larger expression, that entire larger
// expression is pulled up into the Project. For example:
//
// SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
// FROM v2;
//
// The input to RewritePredicateSubquery#apply is the following plan:
//
// Aggregate [(sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29]
// : +- LocalRelation [c3#44L]
// +- LocalRelation [col2#34, col3#35]
//
// This handler transforms the plan into:
//
// Project [(sum(col2)#45L IN (list#28 []) AND (sum(col3)#46L > -1)) AS x#29]
// : +- LocalRelation [c3#44L]
// +- Aggregate [sum(col2#34) AS sum(col2)#45L, sum(col3#35) AS sum(col3)#46L]
// +- LocalRelation [col2#34, col3#35]
//
// Note that the entire AND expression was pulled up into the Project, but the Aggregate
// node continues to perform the aggregations (but without the IN-subquery expression).
case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) =>
- // find expressions with an IN-subquery whose left-hand operand contains aggregates
// Find any interesting expressions from Aggregate.aggregateExpressions.
//
// An interesting expression is one that contains an IN-subquery whose left-hand
// operand contains aggregates. For example:
//
// SELECT col1, SUM(col2) IN (SELECT c2 FROM v1)
// FROM v2 GROUP BY col1;
//
// withInSubquery will be a List containing a single Alias expression:
//
// List(sum(col2#12) IN (list#8 []) AS (...)#19)
val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_))

- // extract the aggregate expressions from withInSubquery
// Extract the aggregate expressions from each interesting expression. This will include
// any aggregate expressions that were not part of the IN-subquery but were part
// of the larger containing expression.
val inSubqueryMapping = withInSubquery.map { e =>
(e, extractAggregateExpressions(e))
}

// Map each interesting expression to its contained aggregate expressions.
//
// Example #1:
//
// SELECT col1, SUM(col2) IN (SELECT c2 FROM v1)
// FROM v2 GROUP BY col1;
//
// inSubqueryMap will have a single entry mapping an Alias expression to a Vector
// with a single aggregate expression:
//
// Map(
// sum(col2#100) IN (list []) AS (...)#107 -> Vector(sum(col2#100))
// )
//
// Example #2:
//
// SELECT (SUM(col1), SUM(col2)) IN (SELECT c1, c2 FROM v1)
// FROM v2;
//
// inSubqueryMap will have a single entry mapping an Alias expression to a Vector
// with two aggregate expressions:
//
// Map(
// named_struct(_0, sum(col1#169), _1, sum(col2#170)) IN (list#166 []) AS (...)#179
// -> Vector(sum(col1#169), sum(col2#170))
// )
//
// Example #3:
//
// SELECT SUM(col1) IN (SELECT c1 FROM v1), SUM(col2) IN (SELECT c2 FROM v1)
// FROM v2;
//
// inSubqueryMap will have two entries, each mapping an Alias expression to a Vector
// with a single aggregate expression:
//
// Map(
// sum(col1#193) IN (list#189 []) AS (...)#207 -> Vector(sum(col1#193)),
// sum(col2#194) IN (list#190 []) AS (...)#208 -> Vector(sum(col2#194))
// )
//
// Example #4:
//
// SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
// FROM v2;
//
// inSubqueryMap will contain a single AND expression that maps to two aggregate
// expressions, even though only one of those aggregate expressions is used as
// the left-hand operand of the IN-subquery expression.
//
// Map(
// (sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29
// -> Vector(sum(col2#34), sum(col3#35))
// )
//
// The keys of inSubqueryMap will be used to determine which expressions in
// the old Aggregate node are interesting. The values of inSubqueryMap, after
// being wrapped in Alias expressions, will replace their associated interesting
// expressions in a new Aggregate node.
val inSubqueryMap = inSubqueryMapping.toMap
- // get all aggregate expressions found in left-hand operands of IN-subqueries

// Get all aggregate expressions associated with interesting expressions.
val aggregateExprs = inSubqueryMapping.flatMap(_._2)
- // create aliases for each above aggregate expression
// Create an alias for each of the above aggregate expressions. We can't use the
// aggregate expressions directly in the new Aggregate node because
// Aggregate.aggregateExpressions has the type Seq[NamedExpression].
val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))())
- // create a mapping from each aggregate expression to its alias
// Create a mapping from each aggregate expression to its alias.
val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap
- // create attributes from those aliases of aggregate expressions
// Create attributes from those aliases of aggregate expressions. These attributes
// will be used in the new Project node to refer to the aliased aggregate expressions
// in the new Aggregate node.
val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute)
- // create a mapping from aggregate expressions to attributes
// Create a mapping from aggregate expressions to attributes. This will be
// used when patching the interesting expressions after they are pulled up
// into the new Project node: aggregate expressions will be replaced by attributes.
val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap

- // create an Aggregate node without the offending IN-subqueries, just
- // the aggregates themselves and all the other aggregate expressions.
// Create an Aggregate node without the interesting expressions, just
// the associated aggregate expressions plus any other group-by or aggregate expressions
// that were not involved in the interesting expressions.
val newAggregateExpressions = a.aggregateExpressions.flatMap {
- // if this expression contains IN-subqueries with aggregates in the left-hand
- // operand, replace with just the aggregates
// If this expression contains IN-subqueries with aggregates in the left-hand
// operand, replace with just the aggregates.
case ae: Expression if inSubqueryMap.contains(ae) =>
- // replace the expression with an aliased aggregate expression
// Replace the expression with an aliased aggregate expression.
inSubqueryMap(ae).map(aggregateExprAliasMap(_))
- case ae @ _ => Seq(ae)
case ae => Seq(ae)
}
val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions)

// Create a projection with the IN-subquery expressions that contain aggregates, replacing
- // the aggregate expressions with attribute references to the output of the Aggregate
// the aggregate expressions with attribute references to the output of the new Aggregate
// operator. Also include the other output of the Aggregate operator.
val projList = a.aggregateExpressions.map {
- // if this expression contains an IN-subquery that uses an aggregate, we
// If this expression contains an IN-subquery that uses an aggregate, we
// need to do something special
case ae: Expression if inSubqueryMap.contains(ae) =>
ae.transform {
- // patch any aggregate expression with its corresponding attribute
// Patch any aggregate expression with its corresponding attribute.
case a: AggregateExpression => aggregateExprAttrMap(a)
}.asInstanceOf[NamedExpression]
- case ae @ _ => ae.toAttribute
case ae => ae.toAttribute
}
val newProj = Project(projList, newAggregate)

- // reapply this rule, now with a Project as parent to the Aggregate
- apply(Project(projList, newAggregate))
// Reapply this rule, but now with all interesting expressions
// from Aggregate.aggregateExpressions pulled up into a Project node.
apply(newProj)

case u: UnaryNode if u.expressions.exists(
SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
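
The pull-up performed by the new Aggregate handler can be illustrated outside Spark with a toy expression tree. The following is only a minimal sketch under stated assumptions: every type and function in it (`Expr`, `Sum`, `Attr`, `Named`, `PullUpSketch`, `pullUp`, and so on) is hypothetical and stands in for a Catalyst class rather than using one, and attributes are matched by name here, whereas the code in this commit builds real aliases and attribute references. It shows the same two outputs the handler builds: a new aggregate list containing only the aliased aggregates, and a project list in which each aggregate has been replaced by an attribute.

```scala
// Toy stand-ins for Catalyst classes; every name below is hypothetical.
sealed trait Expr
case class Col(name: String) extends Expr
case class Lit(value: Long) extends Expr
case class Sum(child: Expr) extends Expr      // plays the role of AggregateExpression
case class Attr(name: String) extends Expr    // plays the role of an attribute reference
case class InSubquery(left: Expr, subquery: String) extends Expr
case class And(left: Expr, right: Expr) extends Expr
case class Gt(left: Expr, right: Expr) extends Expr
case class Named(name: String, child: Expr)   // plays the role of Alias

object PullUpSketch {
  private def children(e: Expr): Seq[Expr] = e match {
    case Sum(c)           => Seq(c)
    case InSubquery(l, _) => Seq(l)
    case And(l, r)        => Seq(l, r)
    case Gt(l, r)         => Seq(l, r)
    case _                => Nil
  }

  // All aggregates in an expression, in left-to-right order.
  def aggregates(e: Expr): Seq[Sum] = e match {
    case s: Sum => Seq(s)
    case other  => children(other).flatMap(aggregates)
  }

  // "Interesting": some IN-subquery has an aggregate in its left-hand operand.
  def isInteresting(e: Expr): Boolean = e match {
    case InSubquery(l, _) if aggregates(l).nonEmpty => true
    case other => children(other).exists(isInteresting)
  }

  // Simplified alias name (assumption: not the same as Spark's toPrettySQL).
  private def aliasName(s: Sum): String = s.child match {
    case Col(n) => s"sum($n)"
    case other  => s"sum($other)"
  }

  // Replace every aggregate with an attribute that refers to its alias.
  private def patch(e: Expr): Expr = e match {
    case s: Sum           => Attr(aliasName(s))
    case InSubquery(l, q) => InSubquery(patch(l), q)
    case And(l, r)        => And(patch(l), patch(r))
    case Gt(l, r)         => Gt(patch(l), patch(r))
    case leaf             => leaf
  }

  // Returns (project list, new aggregate list).
  def pullUp(aggList: Seq[Named]): (Seq[Named], Seq[Named]) = {
    // Interesting expressions are replaced by their aliased aggregates.
    val newAggList = aggList.flatMap { ne =>
      if (isInteresting(ne.child)) aggregates(ne.child).map(s => Named(aliasName(s), s))
      else Seq(ne)
    }
    // Interesting expressions move up, patched; everything else becomes an attribute.
    val projList = aggList.map { ne =>
      if (isInteresting(ne.child)) Named(ne.name, patch(ne.child))
      else Named(ne.name, Attr(ne.name))
    }
    (projList, newAggList)
  }

  def main(args: Array[String]): Unit = {
    // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x FROM v2
    val x = Named("x", And(
      InSubquery(Sum(Col("col2")), "SELECT c3 FROM v1"),
      Gt(Sum(Col("col3")), Lit(-1))))
    val (proj, agg) = pullUp(Seq(x))
    println(s"Project:   $proj")
    println(s"Aggregate: $agg")
  }
}
```

As in the handler, the whole AND expression moves into the project list, while both aggregates, including the one that is not an operand of the IN-subquery, stay in the aggregate list under an alias.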