diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 5a4e9f37c3951..378081221c8c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -115,6 +116,26 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = {
+    exprs.exists { expr =>
+      exprContainsAggregateInSubquery(expr)
+    }
+  }
+
+  def exprContainsAggregateInSubquery(expr: Expression): Boolean = {
+    expr.exists {
+      case InSubquery(values, _) =>
+        values.exists { v =>
+          v.exists {
+            case _: AggregateExpression => true
+            case _ => false
+          }
+        }
+      case _ => false
+    }
+  }
+
+
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) {
     case Filter(condition, child)
@@ -246,46 +267,106 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       }
     }
 
+    // Handle the case where the left-hand side of an IN-subquery contains an aggregate.
+    //
+    // If an Aggregate node contains such an IN-subquery, this handler will pull up all
+    // expressions from the Aggregate node into a new Project node. The new Project node
+    // will then be handled by the unary node handler.
+    //
+    // The unary node handler uses the left-hand side of the IN-subquery in a
+    // join condition. Thus, without this pre-transformation, the join condition
+    // would contain an aggregate, which is illegal. With this pre-transformation,
+    // the join condition instead contains an attribute that the Project node binds
+    // to the left-hand side of the IN-subquery.
+    //
+    // For example:
+    //
+    //   SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+    //   FROM v2;
+    //
+    // The above query has this plan on entry to RewritePredicateSubquery#apply:
+    //
+    //   Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13]
+    //   :  +- LocalRelation [c3#28L]
+    //   +- LocalRelation [col2#18, col3#19]
+    //
+    // Note that the Aggregate node contains the IN-subquery and the left-hand
+    // side of the IN-subquery is an aggregate expression (sum(col2#18)).
+    //
+    // This handler transforms the above plan into the following:
+    // scalastyle:off line.size.limit
+    //
+    //   Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13]
+    //   :  +- LocalRelation [c3#28L]
+    //   +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L]
+    //      +- LocalRelation [col2#18, col3#19]
+    //
+    // scalastyle:on line.size.limit
+    //
+    // Note that both the IN-subquery and the greater-than expressions have been
+    // pulled up into the Project node. These expressions use attributes
+    // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations
+    // which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)).
+    case p @ PhysicalAggregation(
+        groupingExpressions, aggregateExpressions, resultExpressions, child)
+        if exprsContainsAggregateInSubquery(p.expressions) =>
+      val aggExprs = aggregateExpressions.map(
+        ae => Alias(ae, "_aggregateexpression")(ae.resultId))
+      val aggExprIds = aggExprs.map(_.exprId).toSet
+      val resExprs = resultExpressions.map(_.transform {
+        case a: AttributeReference if aggExprIds.contains(a.exprId) =>
+          a.withName("_aggregateexpression")
+      }.asInstanceOf[NamedExpression])
+      // Rewrite the projection and the aggregate separately and then piece them together.
+      val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
+      val newProj = Project(resExprs, newAgg)
+      handleUnaryNode(newProj)
+
     case u: UnaryNode if u.expressions.exists(
-        SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
-      var newChild = u.child
-      var introducedAttrs = Seq.empty[Attribute]
-      val updatedNode = u.mapExpressions(expr => {
-        val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
-        newChild = p
-        introducedAttrs ++= newAttrs
-        // The newExpr can not be None
-        newExpr.get
-      }).withNewChildren(Seq(newChild))
-      updatedNode match {
-        case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
-          // If we have introduced new `exists`-attributes that are referenced by
-          // aggregateExpressions within a non-aggregateFunction expression, we wrap them in
-          // first() aggregate function. first() is Spark's executable version of any_value()
-          // aggregate function.
-          // We do this to keep the aggregation valid, i.e avoid references outside of aggregate
-          // functions that are not in grouping expressions.
-          // Note that the same `exists` attr will never appear in groupingExpressions due to
-          // PullOutGroupingExpressions rule.
-          // Also note: the value of `exists` is functionally determined by grouping expressions,
-          // so applying any aggregate function is semantically safe.
-          val aggFunctionReferences = a.aggregateExpressions.
-            flatMap(extractAggregateExpressions).
-            flatMap(_.references).toSet
-          val nonAggFuncReferences =
-            a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
-          val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)
-
-          // Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
-          val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
-            aggExpr.transformUp {
-              case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
-                new First(attr).toAggregateExpression()
-            }.asInstanceOf[NamedExpression]
-          }
-          a.copy(aggregateExpressions = newAggregateExpressions)
-        case _ => updatedNode
-      }
+        SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u)
   }
 
+  /**
+   * Rewrite a unary node whose expressions contain an IN-subquery or a correlated
+   * EXISTS predicate, turning the existential parts of those expressions into
+   * existence joins on a new child plan.
+   */
+  private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
+    var newChild = u.child
+    var introducedAttrs = Seq.empty[Attribute]
+    val updatedNode = u.mapExpressions(expr => {
+      val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
+      newChild = p
+      introducedAttrs ++= newAttrs
+      // The newExpr cannot be None
+      newExpr.get
+    }).withNewChildren(Seq(newChild))
+    updatedNode match {
+      case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
+        // If we have introduced new `exists`-attributes that are referenced by
+        // aggregateExpressions within a non-aggregateFunction expression, we wrap them in
+        // the first() aggregate function. first() is Spark's executable version of the
+        // any_value() aggregate function.
+        // We do this to keep the aggregation valid, i.e. avoid references outside of
+        // aggregate functions that are not in grouping expressions.
+        // Note that the same `exists` attr will never appear in groupingExpressions due to
+        // the PullOutGroupingExpressions rule.
+        // Also note: the value of `exists` is functionally determined by grouping expressions,
+        // so applying any aggregate function is semantically safe.
+        val aggFunctionReferences = a.aggregateExpressions.
+          flatMap(extractAggregateExpressions).
+          flatMap(_.references).toSet
+        val nonAggFuncReferences =
+          a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
+        val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)
+
+        // Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
+        val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
+          aggExpr.transformUp {
+            case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
+              new First(attr).toAggregateExpression()
+          }.asInstanceOf[NamedExpression]
+        }
+        a.copy(aggregateExpressions = newAggregateExpressions)
+      case _ => updatedNode
+    }
+  }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
index 17547bbcb9402..c45a761353c85 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
+import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
 import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.LongType
 
 class RewriteSubquerySuite extends PlanTest {
 
@@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
     Optimize.executeAndTrack(query.analyze, tracker)
     assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0)
   }
+
+  test("SPARK-50091: Don't put aggregate expression in join condition") {
+    val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
+    val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
+    val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
+    val optimized = Optimize.execute(plan.analyze)
+    val aggregate = relation2
+      .select($"col2")
+      .groupBy()(sum($"col2").as("_aggregateexpression"))
+    val correctAnswer = aggregate
+      .join(relation1.select(Cast($"c3", LongType).as("c3")),
+        ExistenceJoin($"exists".boolean.withNullability(false)),
+        Some($"_aggregateexpression" === $"c3"))
+      .select($"exists".as("(sum(col2) IN (listquery()))")).analyze
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 9e97c224736d8..e7e41f6570d3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest
       checkAnswer(df3, Row(7))
     }
   }
+
+  test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
+    withTempView("v1", "v2") {
+      Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
+        .toDF("c1", "c2", "c3")
+        .createOrReplaceTempView("v1")
+      Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
+        .toDF("col1", "col2", "col3")
+        .createOrReplaceTempView("v2")
+
+      val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1")
+      checkAnswer(df1,
+        Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil)
+
+      val df2 = sql("""SELECT
+        |  col1,
+        |  SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) IN (SELECT c2 FROM v1) AS x
+        |FROM v2 GROUP BY col1
+        |ORDER BY col1""".stripMargin)
+      checkAnswer(df2,
+        Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+
+      val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1) AS x
+        |FROM v2
+        |GROUP BY col1
+        |ORDER BY col1""".stripMargin)
+      checkAnswer(df3,
+        Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+    }
+  }
 }
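
For context, here is a minimal sketch (not part of the patch) of the plan shape that exercises the new `PhysicalAggregation` branch. It uses the same Catalyst test DSL as the `RewriteSubquerySuite` test above, with names mirroring that test; the expected-plan comment is simplified (as the test's `correctAnswer` shows, the real optimized plan also casts `c3` to long and prunes the columns of `relation2`):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.ListQuery
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// An Aggregate whose result expression is an IN-subquery with an aggregate,
// sum(col2), as its left-hand operand: the shape this patch newly handles.
val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))

// After RewritePredicateSubquery, the aggregate is computed once and aliased
// as _aggregateexpression in a lower Aggregate; the ExistenceJoin condition
// then references that attribute instead of the aggregate itself (simplified):
//
//   Project [exists AS (sum(col2) IN (listquery()))]
//   +- Join ExistenceJoin(exists), (_aggregateexpression = c3)
//      :- Aggregate [sum(col2) AS _aggregateexpression]
//      :  +- LocalRelation [col1, col2, col3]
//      +- LocalRelation [c3]
```

This preserves the invariant the new comment block describes: the join condition references `_aggregateexpression`, an output attribute of the child Aggregate, so no aggregate function ends up inside the join condition itself.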