From a866ebe19b7602aa7849ec9aba876eb2fad3ac6a Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Fri, 3 Jan 2025 17:53:14 -0800
Subject: [PATCH] Move unary node handler to its own utility method

---
 .../sql/catalyst/optimizer/subquery.scala     | 93 ++++++++++---------
 1 file changed, 49 insertions(+), 44 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 9c5317084b2b6..6ee11009ba943 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -270,9 +270,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     // Handle the case where the left-hand side of an IN-subquery contains an aggregate.
     //
     // This handler pulls up any expression containing such an IN-subquery into a new Project
-    // node, replacing aggregate expressions with attributes, and then re-enters
-    // RewritePredicateSubquery#apply, where the new Project node will be handled
-    // by the Unary node handler.
+    // node, replacing aggregate expressions with attributes. The new Project node will be
+    // handled by the Unary node handler.
     //
     // The Unary node handler uses the left-hand side of the IN-subquery in a
     // join condition. Thus, without this pre-transformation, the join condition
@@ -454,50 +453,56 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       }
       val newProj = Project(projList, newAggregate)
-      // Reapply this rule, but now with all interesting expressions
+      // Call the unary node handler, but now with all interesting expressions
       // from Aggregate.aggregateExpressions pulled up into a Project node.
-      apply(newProj)
+      handleUnaryNode(newProj)
 
     case u: UnaryNode if u.expressions.exists(
-        SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
-      var newChild = u.child
-      var introducedAttrs = Seq.empty[Attribute]
-      val updatedNode = u.mapExpressions(expr => {
-        val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
-        newChild = p
-        introducedAttrs ++= newAttrs
-        // The newExpr can not be None
-        newExpr.get
-      }).withNewChildren(Seq(newChild))
-      updatedNode match {
-        case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
-          // If we have introduced new `exists`-attributes that are referenced by
-          // aggregateExpressions within a non-aggregateFunction expression, we wrap them in
-          // first() aggregate function. first() is Spark's executable version of any_value()
-          // aggregate function.
-          // We do this to keep the aggregation valid, i.e avoid references outside of aggregate
-          // functions that are not in grouping expressions.
-          // Note that the same `exists` attr will never appear in groupingExpressions due to
-          // PullOutGroupingExpressions rule.
-          // Also note: the value of `exists` is functionally determined by grouping expressions,
-          // so applying any aggregate function is semantically safe.
-          val aggFunctionReferences = a.aggregateExpressions.
-            flatMap(extractAggregateExpressions).
-            flatMap(_.references).toSet
-          val nonAggFuncReferences =
-            a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
-          val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)
-
-          // Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
-          val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
-            aggExpr.transformUp {
-              case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
-                new First(attr).toAggregateExpression()
-            }.asInstanceOf[NamedExpression]
-          }
-          a.copy(aggregateExpressions = newAggregateExpressions)
-        case _ => updatedNode
-      }
+        SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u)
+  }
+
+  /**
+   * Handle the unary node case
+   */
+  private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
+    var newChild = u.child
+    var introducedAttrs = Seq.empty[Attribute]
+    val updatedNode = u.mapExpressions(expr => {
+      val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
+      newChild = p
+      introducedAttrs ++= newAttrs
+      // The newExpr can not be None
+      newExpr.get
+    }).withNewChildren(Seq(newChild))
+    updatedNode match {
+      case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
+        // If we have introduced new `exists`-attributes that are referenced by
+        // aggregateExpressions within a non-aggregateFunction expression, we wrap them in
+        // first() aggregate function. first() is Spark's executable version of any_value()
+        // aggregate function.
+        // We do this to keep the aggregation valid, i.e avoid references outside of aggregate
+        // functions that are not in grouping expressions.
+        // Note that the same `exists` attr will never appear in groupingExpressions due to
+        // PullOutGroupingExpressions rule.
+        // Also note: the value of `exists` is functionally determined by grouping expressions,
+        // so applying any aggregate function is semantically safe.
+        val aggFunctionReferences = a.aggregateExpressions.
+          flatMap(extractAggregateExpressions).
+          flatMap(_.references).toSet
+        val nonAggFuncReferences =
+          a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
+        val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)
+
+        // Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
+        val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
+          aggExpr.transformUp {
+            case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
+              new First(attr).toAggregateExpression()
+          }.asInstanceOf[NamedExpression]
+        }
+        a.copy(aggregateExpressions = newAggregateExpressions)
+      case _ => updatedNode
+    }
   }
 
   /**
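
Reviewer note (not part of the patch): the control-flow change above can be
summarized with a short, self-contained Scala sketch. The plan types and rule
below are toy stand-ins invented for illustration; only the shape of the
change mirrors the patch. Before, the Aggregate arm re-entered apply() and
relied on the UnaryNode arm matching the freshly built Project; now both arms
call the shared private handler directly.

    // Toy stand-ins for Catalyst plan nodes (illustrative only, not Spark's).
    sealed trait Plan
    case class Leaf(name: String) extends Plan
    case class Project(child: Plan) extends Plan
    case class Aggregate(child: Plan) extends Plan

    object ToyRewriteRule {
      def apply(plan: Plan): Plan = plan match {
        case a: Aggregate =>
          // Before the patch: apply(Project(a)), i.e. recurse into the rule
          // and let the Project arm below pick up the new node.
          handleUnaryNode(Project(a))
        case p: Project =>
          handleUnaryNode(p)
        case other => other
      }

      // The extracted handler: the rewrite logic lives in one place and is
      // called directly from both match arms.
      private def handleUnaryNode(u: Plan): Plan = u
    }

For example, ToyRewriteRule(Aggregate(Leaf("t1"))) now reaches handleUnaryNode
in a single, explicit step instead of a second pass through the rule's
pattern match.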
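For context, the Aggregate arm that now calls handleUnaryNode(newProj) fires
when the left-hand side of an IN-subquery contains an aggregate, as the
comment in the first hunk describes. A hypothetical query of that shape
(tables t1, t2 and all column names are invented; spark is assumed to be an
active SparkSession):

    // sum(a), an aggregate expression, is the left-hand side of the
    // IN-subquery, so it is first pulled up into a Project node; the unary
    // node handler then builds the join condition from the projected
    // attribute rather than from the raw aggregate expression.
    spark.sql("""
      SELECT c, sum(a) IN (SELECT b FROM t2) AS agg_in_subquery
      FROM t1
      GROUP BY c
    """)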
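Similarly, the WRAP_EXISTS_IN_AGGREGATE_FUNCTION branch that moves into
handleUnaryNode concerns correlated EXISTS subqueries rewritten under an
Aggregate. A hypothetical query of that shape (again, all table and column
names invented):

    // The correlated EXISTS is rewritten into an `exists` attribute produced
    // below the Aggregate. That attribute is referenced outside any aggregate
    // function and is not a grouping expression, so the rule wraps it as
    // first(exists) to keep the aggregation valid, per the comment in the
    // moved code.
    spark.sql("""
      SELECT c,
             EXISTS (SELECT 1 FROM t2 WHERE t2.b = t1.c) AND sum(a) > 0 AS flag
      FROM t1
      GROUP BY c
    """)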