[SPARK-50091][SQL][3.5] Handle case of aggregates in left-hand operand of IN-subquery #49663

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -100,6 +101,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}

def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = {
exprs.exists { expr =>
exprContainsAggregateInSubquery(expr)
}
}

def exprContainsAggregateInSubquery(expr: Expression): Boolean = {
expr.exists {
case InSubquery(values, _) =>
values.exists { v =>
v.exists {
case _: AggregateExpression => true
case _ => false
}
}
case _ => false
}
}
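
// --- Illustrative sketch (editor's note, not part of this patch) ---
// Expression.exists performs a top-down search of the expression tree. Below is
// a minimal, self-contained model of the detection above that can be pasted into
// a Scala REPL; Attr, Agg and InSub are made-up stand-ins for Catalyst's
// AttributeReference, AggregateExpression and InSubquery.
sealed trait Expr {
  def children: Seq[Expr]
  def exists(p: Expr => Boolean): Boolean = p(this) || children.exists(_.exists(p))
}
case class Attr(name: String) extends Expr { def children: Seq[Expr] = Nil }
case class Agg(child: Expr) extends Expr { def children: Seq[Expr] = Seq(child) }
case class InSub(values: Seq[Expr]) extends Expr { def children: Seq[Expr] = values }

def containsAggInSubquery(e: Expr): Boolean = e.exists {
  case InSub(values) => values.exists(_.exists { case _: Agg => true; case _ => false })
  case _ => false
}

assert(containsAggInSubquery(InSub(Seq(Agg(Attr("col2"))))))  // SUM(col2) IN (...): detected
assert(!containsAggInSubquery(InSub(Seq(Attr("col2")))))      // col2 IN (...): not detected
// --- end sketch ---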

def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case Filter(condition, child)
@@ -162,15 +182,75 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
Project(p.output, Filter(newCond.get, inputPlan))
}

// Handle the case where the left-hand side of an IN-subquery contains an aggregate.
//
// If an Aggregate node contains such an IN-subquery, this handler will pull up all
// expressions from the Aggregate node into a new Project node. The new Project node
// will then be handled by the Unary node handler.
//
// The Unary node handler uses the left-hand side of the IN-subquery in a
// join condition. Thus, without this pre-transformation, the join condition
// contains an aggregate, which is illegal. With this pre-transformation, the
// join condition contains an attribute from the left-hand side of the
// IN-subquery contained in the Project node.
//
// For example:
//
// SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
// FROM v2;
//
// The above query has this plan on entry to RewritePredicateSubquery#apply:
//
// Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13]
// : +- LocalRelation [c3#28L]
// +- LocalRelation [col2#18, col3#19]
//
// Note that the Aggregate node contains the IN-subquery and the left-hand
// side of the IN-subquery is an aggregate expression (sum(col2#18)).
//
// This handler transforms the above plan into the following:
// scalastyle:off line.size.limit
//
// Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13]
// : +- LocalRelation [c3#28L]
// +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L]
// +- LocalRelation [col2#18, col3#19]
//
// scalastyle:on line.size.limit
// Note that both the IN-subquery and the greater-than expressions have been
// pulled up into the Project node. These expressions use attributes
// (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations
// which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)).
case p @ PhysicalAggregation(
groupingExpressions, aggregateExpressions, resultExpressions, child)
if exprsContainsAggregateInSubquery(p.expressions) =>
val aggExprs = aggregateExpressions.map(
ae => Alias(ae, "_aggregateexpression")(ae.resultId))
val aggExprIds = aggExprs.map(_.exprId).toSet
val resExprs = resultExpressions.map(_.transform {
case a: AttributeReference if aggExprIds.contains(a.exprId) =>
a.withName("_aggregateexpression")
}.asInstanceOf[NamedExpression])
// Rewrite the projection and the aggregate separately and then piece them together.
val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
val newProj = Project(resExprs, newAgg)
handleUnaryNode(newProj)
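
// --- Illustrative sketch (editor's note, not part of this patch) ---
// In user-facing terms, the pre-transformation above is semantically the same as
// hand-rewriting the example query so that aggregation happens first and the
// IN-subquery only sees plain attributes. A sketch, assuming temp views v1/v2 as
// in the comment above, an active SparkSession `spark`, and made-up aliases s2/s3:
val rewritten = spark.sql("""
  SELECT s2 IN (SELECT c3 FROM v1) AND s3 > -1 AS x
  FROM (SELECT SUM(col2) AS s2, SUM(col3) AS s3 FROM v2)
""")
// --- end sketch ---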

case u: UnaryNode if u.expressions.exists(
-  SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
-  var newChild = u.child
-  u.mapExpressions(expr => {
-    val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild)
-    newChild = p
-    // The newExpr can not be None
-    newExpr.get
-  }).withNewChildren(Seq(newChild))
+  SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u)
}

/**
* Handle the unary node case: rewrite the IN and correlated EXISTS subqueries in
* the node's expressions into joins on its child via rewriteExistentialExpr.
*/
private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
var newChild = u.child
u.mapExpressions(expr => {
val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild)
newChild = p
// The newExpr cannot be None
newExpr.get
}).withNewChildren(Seq(newChild))
}
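
// --- Illustrative sketch (editor's note, not part of this patch) ---
// handleUnaryNode threads a growing child plan through the expression rewrites:
// each call to rewriteExistentialExpr may stack a join under the current child,
// and every subsequent expression must be rewritten against that updated child.
// The same state-threading pattern in isolation (toy types, for illustration):
def thread[E, P](exprs: Seq[E], child: P)(rewrite: (E, P) => (E, P)): (Seq[E], P) =
  exprs.foldLeft((Seq.empty[E], child)) { case ((done, c), e) =>
    val (e2, c2) = rewrite(e, c)
    (done :+ e2, c2)
  }
// --- end sketch ---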

/**
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
+import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.LongType


class RewriteSubquerySuite extends PlanTest {
@@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
Optimize.executeAndTrack(query.analyze, tracker)
assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0)
}

test("SPARK-50091: Don't put aggregate expression in join condition") {
val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
val optimized = Optimize.execute(plan.analyze)
val aggregate = relation2
.select($"col2")
.groupBy()(sum($"col2").as("_aggregateexpression"))
val correctAnswer = aggregate
.join(relation1.select(Cast($"c3", LongType).as("c3")),
ExistenceJoin($"exists".boolean.withNullability(false)),
Some($"_aggregateexpression" === $"c3"))
.select($"exists".as("(sum(col2) IN (listquery()))")).analyze
comparePlans(optimized, correctAnswer)
}
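
// Note on the expected plan above (editor's note): SUM over an INT column yields
// a BIGINT, so type coercion wraps c3 in Cast(c3, LongType) before the IN
// comparison. After the rewrite, the ExistenceJoin condition compares the
// attribute _aggregateexpression, computed by the Aggregate below the join,
// with c3; the aggregate function itself no longer appears in the join condition.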
}
sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala (30 additions, 0 deletions)
@@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest
checkAnswer(df3, Row(7))
}
}

test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
withView("v1", "v2") {
Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
.toDF("c1", "c2", "c3")
.createOrReplaceTempView("v1")
Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
.toDF("col1", "col2", "col3")
.createOrReplaceTempView("v2")

val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1")
checkAnswer(df1,
Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil)

val df2 = sql("""SELECT
| col1,
| SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN (SELECT c2 FROM v1) AS x
|FROM v2 GROUP BY col1
|ORDER BY col1""".stripMargin)
checkAnswer(df2,
Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)

val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1) AS x
|FROM v2
|GROUP BY col1
|ORDER BY col1""".stripMargin)
checkAnswer(df3,
Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
}
}
}
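
// --- Illustrative sketch (editor's note, not part of this patch) ---
// The first test case above can also be reproduced interactively in spark-shell
// (assumes an active SparkSession `spark` and its implicits):
import spark.implicits._

Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
  .toDF("c1", "c2", "c3").createOrReplaceTempView("v1")
Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
  .toDF("col1", "col2", "col3").createOrReplaceTempView("v2")

spark.sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1").show()
// Expected: col1=1 -> false (sum=5, not among v1.c3), col1=2 -> true (sum=2),
// col1=3 -> true (sum=8). Before this fix, such a query could fail at planning
// time because the aggregate ended up inside the rewritten join's condition.
// --- end sketch ---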