[SPARK-50091][SQL] Handle case of aggregates in left-hand operand of IN-subquery #48627

Closed
wants to merge 17 commits into from

bersprockets
Contributor

What changes were proposed in this pull request?

This PR adds code to RewritePredicateSubquery#apply to explicitly handle the case where an Aggregate node contains an aggregate expression in the left-hand operand of an IN-subquery expression. The explicit handler moves the IN-subquery expressions out of the Aggregate and into a parent Project node. The Aggregate will continue to perform the aggregations that were used as an operand to the IN-subquery expression, but will not include the IN-subquery expression itself. After pulling up IN-subquery expressions into a Project node, RewritePredicateSubquery#apply is called again to handle the Project as a UnaryNode. The Join will now be inserted between the Project and the Aggregate node, and the join condition will use an attribute rather than an aggregate expression, e.g.:

Project [col1#32, exists#42 AS (sum(col2) IN (listquery()))#40]
+- Join ExistenceJoin(exists#42), (sum(col2)#41L = c2#39L)
   :- Aggregate [col1#32], [col1#32, sum(col2#33) AS sum(col2)#41L]
   :  +- LocalRelation [col1#32, col2#33]
   +- LocalRelation [c2#39L]

sum(col2)#41L in the above join condition, despite how it looks, is the name of the attribute, not an aggregate expression.
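The pull-up described above can be sketched with a toy plan model. This is only an illustration under stated assumptions: the `Expr` and `Plan` types, `swapAggregates`, and `pullUp` are hypothetical stand-ins written for this note, not Catalyst classes and not the code in this PR.

```scala
import scala.collection.mutable

object PullUpSketch {
  // Toy expression and plan types (hypothetical, not Catalyst classes).
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Sum(column: String) extends Expr
  case class InSubquery(lhs: Expr, subquery: String) extends Expr
  case class Alias(child: Expr, name: String) extends Expr

  sealed trait Plan
  case class Relation(name: String) extends Plan
  case class Aggregate(grouping: Seq[Expr], output: Seq[Expr], child: Plan) extends Plan
  case class Project(output: Seq[Expr], child: Plan) extends Plan

  // Swap each Sum for an attribute, remembering the alias the Aggregate must still compute.
  private def swapAggregates(e: Expr, seen: mutable.LinkedHashMap[Sum, Alias]): Expr = e match {
    case s: Sum => Attr(seen.getOrElseUpdate(s, Alias(s, s"sum(${s.column})")).name)
    case InSubquery(lhs, sub) => InSubquery(swapAggregates(lhs, seen), sub)
    case Alias(child, name) => Alias(swapAggregates(child, seen), name)
    case other => other
  }

  // Move the output list into a parent Project; the Aggregate keeps only grouping and aggregation.
  def pullUp(agg: Aggregate): Plan = {
    val seen = mutable.LinkedHashMap.empty[Sum, Alias]
    val projectList = agg.output.map(e => swapAggregates(e, seen))
    val aggOutput = agg.grouping ++ seen.values.toSeq
    Project(projectList, Aggregate(agg.grouping, aggOutput, agg.child))
  }

  // select col1, sum(col2) in (select c2 from v1) from v2 group by col1
  val before: Aggregate = Aggregate(
    Seq(Attr("col1")),
    Seq(Attr("col1"), Alias(InSubquery(Sum("col2"), "select c2 from v1"), "in_result")),
    Relation("v2"))

  val after: Plan = pullUp(before)
}
```

In `after`, the IN-subquery sits in the `Project` and its left-hand operand is the attribute `sum(col2)`, so a later join condition built from that operand would reference an attribute rather than an aggregate function.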

Why are the changes needed?

The following query fails:

create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1);
create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1);

select col1, sum(col2) in (select c2 from v1)
from v2 group by col1;

It fails with this error:

[INTERNAL_ERROR] Cannot generate code for expression: sum(input[1, int, false]) SQLSTATE: XX000

With SPARK_TESTING=1, it fails with this error:

[PLAN_VALIDATION_FAILED_RULE_IN_BATCH] Rule org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery in batch RewriteSubquery generated an invalid plan: Special expressions are placed in the wrong plan:
Aggregate [col1#11], [col1#11, first(exists#20, false) AS (sum(col2) IN (listquery()))#19]
+- Join ExistenceJoin(exists#20), (sum(col2#12) = c2#18L)
   :- LocalRelation [col1#11, col2#12]
   +- LocalRelation [c2#18L]

The issue is that RewritePredicateSubquery builds a Join operator where the join condition contains an aggregate expression.

The bug is in the handler for UnaryNode in RewritePredicateSubquery#apply, which adds a Join below the Aggregate and assumes that the left-hand operand of IN-subquery can be used in the join condition. This works fine for most cases, but not when the left-hand operand is an aggregate expression.

This PR moves the offending IN-subqueries to a Project node, with the aggregates replaced by attributes referring to the aggregate expressions. The resulting join condition now uses those attributes rather than the actual aggregate expressions.
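To make the invalid-plan condition concrete, here is a minimal sketch of a check for aggregate functions inside a join condition. The types and the `containsAggregate` helper are hypothetical and written only for this illustration; Spark's actual plan validation is more involved.

```scala
object JoinConditionCheck {
  // Toy expression tree (hypothetical, not Catalyst classes).
  sealed trait Expr { def children: Seq[Expr] }
  case class Attr(name: String) extends Expr { val children: Seq[Expr] = Nil }
  case class Sum(child: Expr) extends Expr { val children: Seq[Expr] = Seq(child) }
  case class EqualTo(left: Expr, right: Expr) extends Expr {
    val children: Seq[Expr] = Seq(left, right)
  }

  // True when an aggregate function appears anywhere inside the expression.
  def containsAggregate(e: Expr): Boolean = e match {
    case _: Sum => true
    case other => other.children.exists(containsAggregate)
  }

  // Shape of the condition in the invalid plan: (sum(col2#12) = c2#18L).
  val invalidCondition: Expr = EqualTo(Sum(Attr("col2")), Attr("c2"))

  // Shape of the condition after the fix: (sum(col2)#41L = c2#39L).
  // The left side is an attribute whose name merely looks like an aggregate.
  val fixedCondition: Expr = EqualTo(Attr("sum(col2)"), Attr("c2"))
}
```

Here `containsAggregate(invalidCondition)` is true and `containsAggregate(fixedCondition)` is false, which mirrors the difference between the two plans shown in the description.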

Does this PR introduce any user-facing change?

No, other than allowing this type of query to succeed.

How was this patch tested?

New unit tests.

Was this patch authored or co-authored using generative AI tooling?

No.

@github-actions github-actions bot added the SQL label Oct 23, 2024
@bersprockets bersprockets force-pushed the aggregate_in_set_issue branch from 840748d to b073289 Compare October 29, 2024 17:44
@bersprockets bersprockets force-pushed the aggregate_in_set_issue branch from b073289 to 7328f31 Compare November 6, 2024 21:56
@bersprockets bersprockets force-pushed the aggregate_in_set_issue branch from 7328f31 to 0d31cdb Compare November 15, 2024 02:31
@bersprockets
Contributor Author

cc @cloud-fan

@bersprockets bersprockets force-pushed the aggregate_in_set_issue branch from 0d31cdb to 3319192 Compare November 20, 2024 19:15
Contributor

@dtenedor dtenedor left a comment

Thanks for the fix!!

@@ -245,6 +266,55 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
condition = Some(newCondition)))
}
}
case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) =>
Contributor

This file is already over 1000 lines long, can we move this logic to a helper object in another file to improve the code health?

Contributor Author

@dtenedor

I could move the entire new handler into a helper function in another file.

On the other hand, this file contains 6 rules, all related to subqueries in one way or another. They could be split up (in a separate refactor, not by this PR).

@@ -245,6 +266,55 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
condition = Some(newCondition)))
}
}
case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) =>
// find expressions with an IN-subquery whose left-hand operand contains aggregates
Contributor

Please express the comments as full sentences (imperative is OK), starting with capital letters and ending in punctuation.

}

val inSubqueryMap = inSubqueryMapping.toMap
// get all aggregate expressions found in left-hand operands of IN-subqueries
Contributor

It's a bit hard to follow this logic in the code. Can you add a comment with a brief example, showing the query plan and the steps performed here?

case ae: Expression if inSubqueryMap.contains(ae) =>
// replace the expression with an aliased aggregate expression
inSubqueryMap(ae).map(aggregateExprAliasMap(_))
case ae @ _ => Seq(ae)
Contributor

Suggested change
case ae @ _ => Seq(ae)
case ae => Seq(ae)
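For context on this suggestion: in a Scala pattern, `ae @ _` binds the matched value to `ae` while matching anything, and a bare variable pattern `ae` does the same, so the shorter form behaves identically. A small sketch (the object and method names are hypothetical, chosen only for this illustration):

```scala
object BinderSketch {
  // Catch-all written with an explicit wildcard binder.
  def withWildcardBinder(x: Any): Seq[Any] = x match {
    case s: String => Seq(s.toUpperCase)
    case ae @ _ => Seq(ae)
  }

  // Catch-all written as a plain variable pattern.
  def withPlainBinder(x: Any): Seq[Any] = x match {
    case s: String => Seq(s.toUpperCase)
    case ae => Seq(ae)
  }
}
```

Both methods return the same result for any input, which is why the reviewer's form is preferred purely for brevity.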

// patch any aggregate expression with its corresponding attribute
case a: AggregateExpression => aggregateExprAttrMap(a)
}.asInstanceOf[NamedExpression]
case ae @ _ => ae.toAttribute
Contributor

Suggested change
case ae @ _ => ae.toAttribute
case ae => ae.toAttribute

@cloud-fan
Contributor

cc @agubichev @andylam-db

@bersprockets bersprockets force-pushed the aggregate_in_set_issue branch from a5206a6 to 9e6a688 Compare November 26, 2024 17:21
Contributor

@attilapiros attilapiros left a comment

LGTM (just a tiny typo in the comments), but let's wait for a committer who is more familiar with this area.

// +- LocalRelation [col1#28, col2#29]
//
// Note that the Aggregate node contains the IN-subquery and the left-hand
// side of the IN-subquery is an aggregate expression (sum(col2#28)).
Contributor

Suggested change
// side of the IN-subquery is an aggregate expression (sum(col2#28)).
// side of the IN-subquery is an aggregate expression (sum(col2#29)).

//
// The transformation pulled the IN-subquery up into a Project. The left-hand side of the
// IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation
// which is still performed in the Aggregate node (sum(col2#28) AS sum(col2)#36L). The Unary
Contributor

Suggested change
// which is still performed in the Aggregate node (sum(col2#28) AS sum(col2)#36L). The Unary
// which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L). The Unary

Contributor Author

Good catch!

@bersprockets bersprockets force-pushed the aggregate_in_set_issue branch from ba42748 to ccf7302 Compare December 3, 2024 01:13
@bersprockets bersprockets force-pushed the aggregate_in_set_issue branch from ccf7302 to a9434ea Compare January 1, 2025 22:18

// Reapply this rule, but now with all interesting expressions
// from Aggregate.aggregateExpressions pulled up into a Project node.
apply(newProj)
Contributor

This reminds me of the rule RewriteWithExpression, which also needs to rewrite Aggregate first. We should not call apply here in the middle of plan traversal, as apply transforms the plan again, which leads to O(n^2) complexity. Instead, we should add a util function that rewrites a single UnaryNode (without transforming the full tree) and call it both here and in the original case match for UnaryNode.
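The cost concern can be illustrated with a toy traversal that counts node visits. Everything below is a hypothetical model written for this note (a linear chain of nodes, a hand-rolled `transformUp`, and a `needsRewrite` flag); it is not Catalyst code, and real plans are trees rather than chains.

```scala
object TraversalCostSketch {
  // A linear chain of plan nodes; the leaf has no child.
  case class Node(needsRewrite: Boolean, child: Option[Node])

  final class Counter { var visits = 0 }

  // Build a chain with `n` nodes (n >= 1), all flagged for rewriting.
  def chain(n: Int): Node =
    (1 until n).foldLeft(Node(needsRewrite = true, None)) { (c, _) =>
      Node(needsRewrite = true, Some(c))
    }

  // Bottom-up traversal that applies `rule` at every node and counts each visit.
  def transformUp(plan: Node, counter: Counter)(rule: Node => Node): Node = {
    val withNewChild = plan.copy(child = plan.child.map(c => transformUp(c, counter)(rule)))
    counter.visits += 1
    rule(withNewChild)
  }

  // Rewrites exactly one node, without traversing its subtree.
  def rewriteOne(node: Node): Node = node.copy(needsRewrite = false)

  // Re-entrant style: the handler runs the full rule again on the node it just built.
  def applyReentrant(plan: Node, counter: Counter): Node =
    transformUp(plan, counter) {
      case n if n.needsRewrite => applyReentrant(rewriteOne(n), counter)
      case n => n
    }

  // Helper style: the handler only rewrites the current node.
  def applyWithHelper(plan: Node, counter: Counter): Node =
    transformUp(plan, counter) {
      case n if n.needsRewrite => rewriteOne(n)
      case n => n
    }

  def visitsFor(run: (Node, Counter) => Node, size: Int): Int = {
    val counter = new Counter
    run(chain(size), counter)
    counter.visits
  }
}
```

In this model the helper style visits each of the n nodes once, while the re-entrant style re-traverses the subtree below every rewritten node, giving n + n(n+1)/2 visits for a chain of n nodes. Both styles produce the same rewritten chain.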

// withInSubquery will be a List containing a single Alias expression:
//
// List(sum(col2#12) IN (list#8 []) AS (...)#19)
val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_))
Contributor

Once we detect such InSubquery, I think it's much simpler to normalize the Aggregate node to pull up the full result projection to a new Project node, instead of only rewriting the problematic InSubquery. This is also how RewriteWithExpression does it and the code is much simpler and less error-prone. We can even create a util function to reuse the code in RewriteWithExpression.

Contributor

@cloud-fan cloud-fan Jan 6, 2025

BTW I think it's better to always build the query plan tree with this normalized form (Aggregate should only do grouping and aggregating, projection should always happen in Project), but this is a much bigger topic.

Contributor Author

I am delayed in responding to review comments: I am not around my laptop much until next week.

Contributor Author

I didn't create a util function because the PhysicalAggregation extractor does almost all the heavy lifting, and because the version of the code in RewriteWithExpression called applyInternal on the new Aggregate node before making it a child of the new Project node.

@bersprockets bersprockets force-pushed the aggregate_in_set_issue branch from a866ebe to b5ee466 Compare January 21, 2025 00:21
@@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
Optimize.executeAndTrack(query.analyze, tracker)
assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0)
}

test("SPARK-50091: Don't put aggregate expression in join condition") {
Contributor Author

I also updated this test to check the whole optimized plan rather than simply testing that the join condition does not have an aggregate expression.

// which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)).
case p @ PhysicalAggregation(
groupingExpressions, aggregateExpressions, resultExpressions, child)
if exprsContainsAggregateInSubquery(p.expressions) =>
Contributor

Suggested change
if exprsContainsAggregateInSubquery(p.expressions) =>
if exprsContainsAggregateInSubquery(resultExpressions) =>

Contributor

This rewrite only pulls out subquery expressions for Aggregate#aggregateExpressions, not grouping expressions.

Contributor Author

Re: if exprsContainsAggregateInSubquery(resultExpressions) =>.

That won't work with exprsContainsAggregateInSubquery as it currently stands, since that function looks for in-subqueries with aggregate expressions in the left-hand operand. resultExpressions has the aggregate expressions replaced with attributes, so exprsContainsAggregateInSubquery would never trigger.

Alternatively, I could do

if exprsContainsAggregateInSubquery(p.asInstanceOf[Aggregate].aggregateExpressions) =>

which is kind of ugly, but does the trick.

Another alternative: I'm the only one calling exprsContainsAggregateInSubquery, so I could change it to return true if there are any in-subqueries at all with no regard to characteristics of the left-hand operand. We would end up rewriting some cases that wouldn't otherwise cause trouble.
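The point about resultExpressions can be shown with a toy version of the check. The types and the `containsAggregateInSubquery` function below are hypothetical simplifications written for this note; they only mimic the behavior described in the comment above and are not the PR's implementation.

```scala
object ResultExpressionCheck {
  // Toy expression tree (hypothetical, not Catalyst classes).
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Sum(child: Expr) extends Expr
  case class InSubquery(lhs: Expr, subquery: String) extends Expr
  case class Alias(child: Expr, name: String) extends Expr

  // True when an aggregate function appears anywhere inside the expression.
  def containsAggregate(e: Expr): Boolean = e match {
    case _: Sum => true
    case InSubquery(lhs, _) => containsAggregate(lhs)
    case Alias(child, _) => containsAggregate(child)
    case _: Attr => false
  }

  // True when some IN-subquery has an aggregate function in its left-hand operand.
  def containsAggregateInSubquery(e: Expr): Boolean = e match {
    case InSubquery(lhs, _) => containsAggregate(lhs)
    case Alias(child, _) => containsAggregateInSubquery(child)
    case Sum(child) => containsAggregateInSubquery(child)
    case _: Attr => false
  }

  // As written in Aggregate.aggregateExpressions: sum(col2) IN (subquery).
  val original: Expr =
    Alias(InSubquery(Sum(Attr("col2")), "select c2 from v1"), "in_result")

  // As it would appear in resultExpressions: the aggregate is already an attribute.
  val resultExpression: Expr =
    Alias(InSubquery(Attr("sum(col2)"), "select c2 from v1"), "in_result")
}
```

The check fires on `original` but not on `resultExpression`, which is why guarding the case on resultExpressions would never match with the function as described.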

Contributor

ah OK, let's keep it as it is

@@ -2800,4 +2800,32 @@ class SubquerySuite extends QueryTest
checkAnswer(df3, Row(7))
}
}

test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
withTable("v1", "v2") {
Contributor

Suggested change
withTable("v1", "v2") {
withTempView("v1", "v2") {

test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
withTable("v1", "v2") {
sql("""CREATE OR REPLACE TEMP VIEW v1 (c1, c2, c3) AS VALUES
|(1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8)""".stripMargin)
Contributor

nit: Seq((1, 2, 2), (1, 5, 3), ...).toDF("c1", "c2", "c3").createTempView

@cloud-fan
Contributor

thanks, merging to master/4.0!

@cloud-fan cloud-fan closed this in e02ff1c Jan 23, 2025
cloud-fan pushed a commit that referenced this pull request Jan 23, 2025
…IN-subquery

Closes #48627 from bersprockets/aggregate_in_set_issue.

Authored-by: Bruce Robbins <bersprockets@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit e02ff1c)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
@cloud-fan
Contributor

@bersprockets feel free to open a 3.5 backport if it's also an issue there.

bersprockets added a commit to bersprockets/spark that referenced this pull request Jan 24, 2025
…IN-subquery

Closes apache#48627 from bersprockets/aggregate_in_set_issue.

Authored-by: Bruce Robbins <bersprockets@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
dongjoon-hyun pushed a commit that referenced this pull request Jan 25, 2025
…d of IN-subquery

### What changes were proposed in this pull request?

This is a back-port of #48627.

Closes #49663 from bersprockets/aggregate_in_set_issue_br35.

Authored-by: Bruce Robbins <bersprockets@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>