Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50091][SQL] Handle case of aggregates in left-hand operand of IN-subquery #48627

Closed
wants to merge 17 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
Expand Down Expand Up @@ -115,6 +116,26 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}

def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = {
exprs.exists { expr =>
exprContainsAggregateInSubquery(expr)
}
}

def exprContainsAggregateInSubquery(expr: Expression): Boolean = {
expr.exists {
case InSubquery(values, _) =>
values.exists { v =>
v.exists {
case _: AggregateExpression => true
case _ => false
}
}
case _ => false;
}
}


def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case Filter(condition, child)
Expand Down Expand Up @@ -246,46 +267,106 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}

// Handle the case where the left-hand side of an IN-subquery contains an aggregate.
//
// If an Aggregate node contains such an IN-subquery, this handler will pull up all
// expressions from the Aggregate node into a new Project node. The new Project node
// will then be handled by the Unary node handler.
//
// The Unary node handler uses the left-hand side of the IN-subquery in a
// join condition. Thus, without this pre-transformation, the join condition
// contains an aggregate, which is illegal. With this pre-transformation, the
// join condition contains an attribute from the left-hand side of the
// IN-subquery contained in the Project node.
//
// For example:
//
// SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
// FROM v2;
//
// The above query has this plan on entry to RewritePredicateSubquery#apply:
//
// Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13]
// : +- LocalRelation [c3#28L]
// +- LocalRelation [col2#18, col3#19]
//
// Note that the Aggregate node contains the IN-subquery and the left-hand
// side of the IN-subquery is an aggregate expression sum(col2#18)).
//
// This handler transforms the above plan into the following:
// scalastyle:off line.size.limit
//
// Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13]
// : +- LocalRelation [c3#28L]
// +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L]
// +- LocalRelation [col2#18, col3#19]
//
// scalastyle:on
// Note that both the IN-subquery and the greater-than expressions have been
// pulled up into the Project node. These expressions use attributes
// (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations
// which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)).
case p @ PhysicalAggregation(
groupingExpressions, aggregateExpressions, resultExpressions, child)
if exprsContainsAggregateInSubquery(p.expressions) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if exprsContainsAggregateInSubquery(p.expressions) =>
if exprsContainsAggregateInSubquery(resultExpressions) =>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rewrite only pulls out subquery expressions for Aggregate#aggregateExpressions, not grouping expressions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re: if exprsContainsAggregateInSubquery(resultExpressions) =>.

That won't work withexprsContainsAggregateInSubquery as it currently stands, since that function looks for in-subqueries with aggregate expressions in the left-hand operand. resultExpressions has the aggregate expressions replaced with attributes, so exprsContainsAggregateInSubquery would never trigger.

Alternatively, I could do

if exprsContainsAggregateInSubquery(p.asInstanceOf[Aggregate].aggregateExpressions) =>

which is kind of ugly, but does the trick.

Another alternative: I'm the only one calling exprsContainsAggregateInSubquery, so I could change it to return true if there are any in-subqueries at all with no regard to characteristics of the left-hand operand. We would end up rewriting some cases that wouldn't otherwise cause trouble.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah OK, let's keep it as it is

val aggExprs = aggregateExpressions.map(
ae => Alias(ae, "_aggregateexpression")(ae.resultId))
val aggExprIds = aggExprs.map(_.exprId).toSet
val resExprs = resultExpressions.map(_.transform {
case a: AttributeReference if aggExprIds.contains(a.exprId) =>
a.withName("_aggregateexpression")
}.asInstanceOf[NamedExpression])
// Rewrite the projection and the aggregate separately and then piece them together.
val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
val newProj = Project(resExprs, newAgg)
handleUnaryNode(newProj)

case u: UnaryNode if u.expressions.exists(
SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
var newChild = u.child
var introducedAttrs = Seq.empty[Attribute]
val updatedNode = u.mapExpressions(expr => {
val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
newChild = p
introducedAttrs ++= newAttrs
// The newExpr can not be None
newExpr.get
}).withNewChildren(Seq(newChild))
updatedNode match {
case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
// If we have introduced new `exists`-attributes that are referenced by
// aggregateExpressions within a non-aggregateFunction expression, we wrap them in
// first() aggregate function. first() is Spark's executable version of any_value()
// aggregate function.
// We do this to keep the aggregation valid, i.e avoid references outside of aggregate
// functions that are not in grouping expressions.
// Note that the same `exists` attr will never appear in groupingExpressions due to
// PullOutGroupingExpressions rule.
// Also note: the value of `exists` is functionally determined by grouping expressions,
// so applying any aggregate function is semantically safe.
val aggFunctionReferences = a.aggregateExpressions.
flatMap(extractAggregateExpressions).
flatMap(_.references).toSet
val nonAggFuncReferences =
a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)

// Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
aggExpr.transformUp {
case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
new First(attr).toAggregateExpression()
}.asInstanceOf[NamedExpression]
}
a.copy(aggregateExpressions = newAggregateExpressions)
case _ => updatedNode
}
SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u)
}

/**
* Handle the unary node case
*/
private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
var newChild = u.child
var introducedAttrs = Seq.empty[Attribute]
val updatedNode = u.mapExpressions(expr => {
val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
newChild = p
introducedAttrs ++= newAttrs
// The newExpr can not be None
newExpr.get
}).withNewChildren(Seq(newChild))
updatedNode match {
case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
// If we have introduced new `exists`-attributes that are referenced by
// aggregateExpressions within a non-aggregateFunction expression, we wrap them in
// first() aggregate function. first() is Spark's executable version of any_value()
// aggregate function.
// We do this to keep the aggregation valid, i.e avoid references outside of aggregate
// functions that are not in grouping expressions.
// Note that the same `exists` attr will never appear in groupingExpressions due to
// PullOutGroupingExpressions rule.
// Also note: the value of `exists` is functionally determined by grouping expressions,
// so applying any aggregate function is semantically safe.
val aggFunctionReferences = a.aggregateExpressions.
flatMap(extractAggregateExpressions).
flatMap(_.references).toSet
val nonAggFuncReferences =
a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)

// Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
aggExpr.transformUp {
case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
new First(attr).toAggregateExpression()
}.asInstanceOf[NamedExpression]
}
a.copy(aggregateExpressions = newAggregateExpressions)
case _ => updatedNode
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.LongType


class RewriteSubquerySuite extends PlanTest {
Expand Down Expand Up @@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
Optimize.executeAndTrack(query.analyze, tracker)
assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0)
}

test("SPARK-50091: Don't put aggregate expression in join condition") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also updated this test to check the whole optimized plan rather than simply testing that the join condition does not have an aggregate expression.

val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
val optimized = Optimize.execute(plan.analyze)
val aggregate = relation2
.select($"col2")
.groupBy()(sum($"col2").as("_aggregateexpression"))
val correctAnswer = aggregate
.join(relation1.select(Cast($"c3", LongType).as("c3")),
ExistenceJoin($"exists".boolean.withNullability(false)),
Some($"_aggregateexpression" === $"c3"))
.select($"exists".as("(sum(col2) IN (listquery()))")).analyze
comparePlans(optimized, correctAnswer)
}
}
30 changes: 30 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest
checkAnswer(df3, Row(7))
}
}

test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
withView("v1", "v2") {
Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
.toDF("c1", "c2", "c3")
.createOrReplaceTempView("v1")
Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
.toDF("col1", "col2", "col3")
.createOrReplaceTempView("v2")

val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1")
checkAnswer(df1,
Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil)

val df2 = sql("""SELECT
| col1,
| SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN (SELECT c2 FROM v1) AS x
|FROM v2 GROUP BY col1
|ORDER BY col1""".stripMargin)
checkAnswer(df2,
Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)

val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1) AS x
|FROM v2
|GROUP BY col1
|ORDER BY col1""".stripMargin)
checkAnswer(df3,
Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
}
}
}
Loading