Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48503][SQL] Fix invalid scalar subqueries with group-by on non-equivalent columns that were incorrectly allowed #46839

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions._
Expand All @@ -41,7 +42,7 @@ import org.apache.spark.util.Utils
/**
* Throws user facing errors when passed invalid queries that fail to analyze.
*/
trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsBase {
trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsBase with Logging {

protected def isView(nameParts: Seq[String]): Boolean

Expand Down Expand Up @@ -912,13 +913,36 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB

// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
// are not part of the correlated columns.

// Note: groupByCols does not contain outer refs - grouping by an outer ref is always ok
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
// Collect the local references from the correlated predicate in the subquery.
val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
.filterNot(conditions.flatMap(_.references).contains)
val correlatedCols = AttributeSet(subqueryColumns)
val invalidCols = groupByCols -- correlatedCols
// GROUP BY columns must be a subset of columns in the predicates
// Collect the inner query attributes that are guaranteed to have a single value for each
// outer row. See comment on getCorrelatedEquivalentInnerColumns.
val correlatedEquivalentCols = getCorrelatedEquivalentInnerColumns(query)
val nonEquivalentGroupByCols = groupByCols -- correlatedEquivalentCols

val invalidCols = if (!SQLConf.get.getConf(
SQLConf.LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE)) {
nonEquivalentGroupByCols
} else {
// Legacy incorrect logic for checking for invalid group-by columns (see SPARK-48503).
// Allows any inner attribute that appears in a correlated predicate, even if it is a
// non-equality predicate or under an operator that can change the values of the attribute
// (see comments on getCorrelatedEquivalentInnerColumns for examples).
val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
.filterNot(conditions.flatMap(_.references).contains)
val correlatedCols = AttributeSet(subqueryColumns)
val invalidColsLegacy = groupByCols -- correlatedCols
if (!nonEquivalentGroupByCols.isEmpty && invalidColsLegacy.isEmpty) {
logWarning("Using legacy behavior for " +
s"${SQLConf.LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE
.key}. Query would be rejected with non-legacy behavior but is allowed by " +
s"legacy behavior. Query may be invalid and return wrong results if the scalar " +
s"subquery's group-by outputs multiple rows.")
}
invalidColsLegacy
}

if (invalidCols.nonEmpty) {
expr.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, LogicalPlan}
import org.apache.spark.sql.catalyst.optimizer.DecorrelateInnerQuery
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -249,6 +250,73 @@ object SubExprUtils extends PredicateHelper {
}
}
}

/**
* Returns the inner query attributes that are guaranteed to have a single value for each
* outer row. Therefore, a scalar subquery is allowed to group-by on these attributes.
* We can derive these from correlated equality predicates, though we need to take care about
* propagating this through operators like OUTER JOIN or UNION.
*
* Positive examples: x = outer(a) AND y = outer(b)
* Negative examples:
* - x <= outer(a)
* - x + y = outer(a)
* - x = outer(a) OR y = outer(b)
* - y = outer(b) + 1 (this and similar expressions could be supported, but very carefully)
* - An equality under the right side of a LEFT OUTER JOIN, e.g.
* select *, (select count(*) from y left join
* (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x;
* - An equality under UNION e.g.
* select *, (select count(*) from
* (select * from y where y1 = x1 union all select * from y) group by y1) from x;
*/
def getCorrelatedEquivalentInnerColumns(plan: LogicalPlan): AttributeSet = {
plan match {
case Filter(cond, child) =>
val correlated = AttributeSet(splitConjunctivePredicates(cond)
.filter(containsOuter) // TODO: can remove this line to allow e.g. where x = 1 group by x
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intend to enable that in a separate PR, to reduce risk here.

.filter(DecorrelateInnerQuery.canPullUpOverAgg)
.flatMap(_.references))
correlated ++ getCorrelatedEquivalentInnerColumns(child)

case Join(left, right, joinType, _, _) =>
joinType match {
case _: InnerLike =>
AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child)))
case LeftOuter => getCorrelatedEquivalentInnerColumns(left)
case RightOuter => getCorrelatedEquivalentInnerColumns(right)
case FullOuter => AttributeSet.empty
case LeftSemi => getCorrelatedEquivalentInnerColumns(left)
case LeftAnti => getCorrelatedEquivalentInnerColumns(left)
case _ => AttributeSet.empty
}

case _: Union => AttributeSet.empty
case Except(left, right, _) => getCorrelatedEquivalentInnerColumns(left)

case
_: Aggregate |
_: Distinct |
_: Intersect |
_: GlobalLimit |
_: LocalLimit |
_: Offset |
_: Project |
_: Repartition |
_: RepartitionByExpression |
_: RebalancePartitions |
_: Sample |
_: Sort |
_: Window |
_: Tail |
_: WithCTE |
_: Range |
_: SubqueryAlias =>
AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child)))

case _ => AttributeSet.empty
Copy link
Contributor Author

@jchen5 jchen5 Jun 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The list of operators handled here is by no means comprehensive and ensuring it covers enough is tricky. I used the list in LogicalPlanVisitor as a starting point, but in my testing I discovered that e.g. SubqueryAlias also needs to be handled to cover cases with FROM subqueries inside the scalar subquery.

Suggestions on other important operators to handle or other potential approaches welcome.

(In the long run I think we need to replace this entire check with a runtime check as described in https://issues.apache.org/jira/browse/SPARK-48501, but that's highly nontrivial)

}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4910,6 +4910,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE =
buildConf("spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate")
.internal()
.doc("When set to true, use incorrect legacy behavior for checking whether a scalar " +
"subquery with a group-by on correlated columns is allowed. See SPARK-48503")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val ALLOW_SUBQUERY_EXPRESSIONS_IN_LAMBDAS_AND_HIGHER_ORDER_FUNCTIONS =
buildConf("spark.sql.analyzer.allowSubqueryExpressionsInLambdasOrHigherOrderFunctions")
.internal()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
create temp view x (x1, x2) as values (1, 1), (2, 2)
-- !query analysis
CreateViewCommand `x`, [(x1,None), (x2,None)], values (1, 1), (2, 2), false, false, LocalTempView, UNSUPPORTED, true
+- LocalRelation [col1#x, col2#x]


-- !query
create temp view y (y1, y2) as values (2, 0), (3, -1)
-- !query analysis
CreateViewCommand `y`, [(y1,None), (y2,None)], values (2, 0), (3, -1), false, false, LocalTempView, UNSUPPORTED, true
+- LocalRelation [col1#x, col2#x]


-- !query
create temp view z (z1, z2) as values (1, 0), (1, 1)
-- !query analysis
CreateViewCommand `z`, [(z1,None), (z2,None)], values (1, 0), (1, 1), false, false, LocalTempView, UNSUPPORTED, true
+- LocalRelation [col1#x, col2#x]


-- !query
select * from x where (select count(*) from y where y1 = x1 group by y1) = 1
-- !query analysis
Project [x1#x, x2#x]
+- Filter (scalar-subquery#x [x1#x] = cast(1 as bigint))
: +- Aggregate [y1#x], [count(1) AS count(1)#xL]
: +- Filter (y1#x = outer(x1#x))
: +- SubqueryAlias y
: +- View (`y`, [y1#x, y2#x])
: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias x
+- View (`x`, [x1#x, x2#x])
+- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select * from x where (select count(*) from y where y1 = x1 group by x1) = 1
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_REFERENCE",
"sqlState" : "0A000",
"messageParameters" : {
"sqlExprs" : "\"x1\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 61,
"stopIndex" : 71,
"fragment" : "group by x1"
} ]
}


-- !query
select * from x where (select count(*) from y where y1 > x1 group by x1) = 1
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_REFERENCE",
"sqlState" : "0A000",
"messageParameters" : {
"sqlExprs" : "\"x1\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 61,
"stopIndex" : 71,
"fragment" : "group by x1"
} ]
}


-- !query
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY",
"sqlState" : "0A000",
"messageParameters" : {
"value" : "y1"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 23,
"stopIndex" : 72,
"fragment" : "(select count(*) from y where y1 > x1 group by y1)"
} ]
}


-- !query
select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY",
"sqlState" : "0A000",
"messageParameters" : {
"value" : "y1"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 11,
"stopIndex" : 65,
"fragment" : "(select count(*) from y where y1 + y2 = x1 group by y1)"
} ]
}


-- !query
select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY",
"sqlState" : "0A000",
"messageParameters" : {
"value" : "y2"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 11,
"stopIndex" : 71,
"fragment" : "(select count(*) from y where x1 = y1 and y2 = 1 group by y2)"
} ]
}


-- !query
select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY",
"sqlState" : "0A000",
"messageParameters" : {
"value" : "y1"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 11,
"stopIndex" : 106,
"fragment" : "(select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1)"
} ]
}


-- !query
select *, (select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY",
"sqlState" : "0A000",
"messageParameters" : {
"value" : "z1"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 11,
"stopIndex" : 103,
"fragment" : "(select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1)"
} ]
}


-- !query
set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = true
-- !query analysis
SetCommand (spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate,Some(true))


-- !query
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1
-- !query analysis
Project [x1#x, x2#x]
+- Filter (scalar-subquery#x [x1#x] = cast(1 as bigint))
: +- Aggregate [y1#x], [count(1) AS count(1)#xL]
: +- Filter (y1#x > outer(x1#x))
: +- SubqueryAlias y
: +- View (`y`, [y1#x, y2#x])
: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias x
+- View (`x`, [x1#x, x2#x])
+- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
reset spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate
-- !query analysis
ResetCommand spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- Tests for scalar subquery with a group-by. Only a group-by that guarantees a single row result is allowed. See SPARK-48503

--ONLY_IF spark

create temp view x (x1, x2) as values (1, 1), (2, 2);
create temp view y (y1, y2) as values (2, 0), (3, -1);
create temp view z (z1, z2) as values (1, 0), (1, 1);

-- Legal queries
select * from x where (select count(*) from y where y1 = x1 group by y1) = 1;
select * from x where (select count(*) from y where y1 = x1 group by x1) = 1;
select * from x where (select count(*) from y where y1 > x1 group by x1) = 1;

-- Illegal queries
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1;
select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x;

-- Equality with literal - disallowed currently but can actually be allowed
select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x;

-- Certain other operators like OUTER JOIN or UNION between the correlating filter and the group-by also can cause the scalar subquery to return multiple values and hence make the query illegal.
select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x;
select *, (select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x; -- The correlation below the join is unsupported in Spark anyway, but when we do support it this query should still be disallowed.

-- Test legacy behavior conf
set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = true;
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1;
reset spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate;
Loading