diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 67e2312f7670e..c45a761353c85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.LongType class RewriteSubquerySuite extends PlanTest { @@ -84,10 +84,16 @@ class RewriteSubquerySuite extends PlanTest { test("SPARK-50091: Don't put aggregate expression in join condition") { val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int) val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int) - val query = relation2.select(sum($"col2").in(ListQuery(relation1.select($"c3")))) - - val optimized = Optimize.execute(query.analyze) - val join = optimized.find(_.isInstanceOf[Join]).get.asInstanceOf[Join] - assert(!join.condition.get.exists(_.isInstanceOf[AggregateExpression])) + val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3")))) + val optimized = Optimize.execute(plan.analyze) + val aggregate = relation2 + .select($"col2") + .groupBy()(sum($"col2").as("_aggregateexpression")) + val correctAnswer = aggregate + .join(relation1.select(Cast($"c3", LongType).as("c3")), + ExistenceJoin($"exists".boolean.withNullability(false)), + Some($"_aggregateexpression" === $"c3")) + .select($"exists".as("(sum(col2) IN (listquery()))")).analyze + comparePlans(optimized, correctAnswer) } }