
[SPARK-31334][SQL] Don't ResolveReference/ResolveMissingReference when the Filter condition contains an aggregate expression #28107

@@ -1393,7 +1393,24 @@ class Analyzer(

      case q: LogicalPlan =>
        logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
-       q.mapExpressions(resolveExpressionTopDown(_, q))
+       q.mapExpressions { e =>
+         q match {
+           case _: Filter if containsAggregate(e) =>
+             e
+           case _ =>
+             resolveExpressionTopDown(e, q)
+         }
+       }
    }

+   def containsAggregate(e: Expression): Boolean = {
Contributor:
why can't we reuse ResolveAggregateFunctions.containsAggregate?

Contributor Author:
Since the function here is still an UnresolvedFunction at this point in analysis, we can't simply reuse it.
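For context, the helper the reviewer refers to only matches expressions that are already resolved to AggregateExpression nodes, so it cannot detect an aggregate that is still an UnresolvedFunction at this stage. A minimal paraphrased sketch of that existing helper (recalled from the Spark code base, not part of this diff):

// Paraphrased sketch of ResolveAggregateFunctions.containsAggregate:
// it only recognizes aggregates already resolved to AggregateExpression,
// so an unresolved call such as sum(a) in a HAVING clause is not detected yet.
def containsAggregate(condition: Expression): Boolean = {
  condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}

That is why the new helper in this patch also looks up UnresolvedFunction names in the session catalog.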

+     e.find {
+       case func: UnresolvedFunction =>
+         v1SessionCatalog.lookupFunction(func.name, func.arguments)
+           .isInstanceOf[AggregateFunction]
+       case _ =>
+         false
+     }.isDefined || e.find(_.isInstanceOf[AggregateExpression]).isDefined
+   }

    def resolveAssignments(

@@ -1679,7 +1696,9 @@ class Analyzer(
            Project(child.output, newSort)
          }

-       case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved =>
+       case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved
+           && !containsAggregate(cond) =>
+
          val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
          if (child.output == newChild.output) {
            f.copy(condition = newCond.head)

@@ -1690,6 +1709,16 @@
        }
      }

+   def containsAggregate(e: Expression): Boolean = {
+     e.find {
+       case func: UnresolvedFunction =>
+         v1SessionCatalog.lookupFunction(func.name, func.arguments)
+           .isInstanceOf[AggregateFunction]
+       case _ =>
+         false
+     }.isDefined || e.find(_.isInstanceOf[AggregateExpression]).isDefined
+   }

/**
* This method tries to resolve expressions and find missing attributes recursively. Specially,
* when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (40 additions, 0 deletions)
@@ -3494,6 +3494,45 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
Seq(Row(Map[Int, Int]()), Row(Map(1 -> 2))))
}

+  test("SPARK-31334: TypeCoercion should run before ResolveAggregateFunctions") {
+    Seq(
+      (1, 3),
+      (2, 3),
+      (3, 6),
+      (4, 7),
+      (5, 9),
+      (6, 9)
+    ).toDF("a", "b").createOrReplaceTempView("testData1")
+
+    checkAnswer(sql(
Contributor:
does this test fail before your patch?

Contributor Author:
No, it won't fail; it is included for contrast.

Contributor:
So this test is not qualified to reproduce the bug?

Contributor Author:
The first SQL is only for comparison; the second one reproduces the bug. If it isn't needed, we can just delete the first one.
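To make the distinction concrete, here is a standalone sketch of the reproducing case, mirroring the second query in the test below (names come from the test; assumes a spark-shell session where `spark` and `spark.implicits._` are in scope). Column `a` is a string, so `sum(a)` in the HAVING clause still needs type coercion, which is why it exercises the bug while the integer variant does not:

// Reproduction sketch (illustrative, mirrors the test added in this PR):
// "a" is a string column, so sum(a) in HAVING requires type coercion and is
// still an unresolved function when the Filter/HAVING condition is analyzed.
Seq(("1", 3), ("2", 3), ("3", 6), ("4", 7), ("5", 9), ("6", 9))
  .toDF("a", "b").createOrReplaceTempView("testData2")

spark.sql(
  """
    |SELECT b, sum(a) AS a
    |FROM testData2
    |GROUP BY b
    |HAVING sum(a) > 3
  """.stripMargin).show()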

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's delete

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's delete

Done

"""
| SELECT b, sum(a) as a
| FROM testData1
| GROUP BY b
| HAVING sum(a) > 3
""".stripMargin),
Row(7, 4) :: Row(9, 11) :: Nil)

+    Seq(
+      ("1", 3),
+      ("2", 3),
+      ("3", 6),
+      ("4", 7),
+      ("5", 9),
+      ("6", 9)
+    ).toDF("a", "b").createOrReplaceTempView("testData2")
+
+    checkAnswer(sql(
+      """
+        | SELECT b, sum(a) as a
+        | FROM testData2
+        | GROUP BY b
+        | HAVING sum(a) > 3
+      """.stripMargin),
+      Row(7, 4.0) :: Row(9, 11.0) :: Nil)
+  }


test("SPARK-31242: clone SparkSession should respect sessionInitWithConfigDefaults") {
// Note, only the conf explicitly set in SparkConf(e.g. in SharedSparkSessionBase) would cause
// problem before the fix.
@@ -3503,6 +3542,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
assert(SQLConf.get.getConf(SQLConf.CODEGEN_FALLBACK) === true)
}
}

}

case class Foo(bar: Option[String])