Skip to content

Commit

Permalink
Make test more explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
bersprockets committed Jan 21, 2025
1 parent cb4066a commit b5ee466
Showing 1 changed file with 14 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.LongType


class RewriteSubquerySuite extends PlanTest {
Expand Down Expand Up @@ -84,10 +84,16 @@ class RewriteSubquerySuite extends PlanTest {
// Test case for SPARK-50091: runs the suite's `Optimize` executor over an IN-subquery whose
// left-hand value is `sum(col2)` and checks what ends up in the resulting join.
// NOTE(review): this hunk is rendered without +/- markers, so it appears to interleave the
// pre-commit body (`query` ... `assert`) with the post-commit body (`plan` ... `comparePlans`);
// `optimized` is declared in both. Confirm against the raw diff before treating this as one body.
test("SPARK-50091: Don't put aggregate expression in join condition") {
// Two local relations with three int columns each; relation1 feeds the subquery (`c3`),
// relation2 is the outer relation that `sum(col2)` is computed over.
val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
// Property-style check (presumably the removed lines): build the query with `select`, locate a
// Join node in the optimized plan, and assert its condition holds no AggregateExpression.
val query = relation2.select(sum($"col2").in(ListQuery(relation1.select($"c3"))))

val optimized = Optimize.execute(query.analyze)
val join = optimized.find(_.isInstanceOf[Join]).get.asInstanceOf[Join]
assert(!join.condition.get.exists(_.isInstanceOf[AggregateExpression]))
// Explicit-plan check (presumably the added lines): the same IN predicate is placed in a
// `groupBy()` with an empty grouping list, and the optimized plan is compared to a
// hand-built expected plan.
val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
val optimized = Optimize.execute(plan.analyze)
// Expected left side of the join: `col2` is projected, then summed and aliased as
// "_aggregateexpression" so the join condition can refer to the alias.
val aggregate = relation2
.select($"col2")
.groupBy()(sum($"col2").as("_aggregateexpression"))
// Expected plan: an ExistenceJoin (non-nullable boolean `exists` attribute) against relation1
// with `c3` cast to LongType, joined on `_aggregateexpression === c3`; the `exists` flag is then
// projected under the name "(sum(col2) IN (listquery()))".
val correctAnswer = aggregate
.join(relation1.select(Cast($"c3", LongType).as("c3")),
ExistenceJoin($"exists".boolean.withNullability(false)),
Some($"_aggregateexpression" === $"c3"))
.select($"exists".as("(sum(col2) IN (listquery()))")).analyze
comparePlans(optimized, correctAnswer)
}
}

0 comments on commit b5ee466

Please sign in to comment.