-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-48503][SQL] Fix invalid scalar subqueries with group-by on non-equivalent columns that were incorrectly allowed #46839
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions | |
import scala.collection.mutable.ArrayBuffer | ||
|
||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression | ||
import org.apache.spark.sql.catalyst.plans.QueryPlan | ||
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, LogicalPlan} | ||
import org.apache.spark.sql.catalyst.optimizer.DecorrelateInnerQuery | ||
import org.apache.spark.sql.catalyst.plans._ | ||
import org.apache.spark.sql.catalyst.plans.logical._ | ||
import org.apache.spark.sql.catalyst.trees.TreePattern._ | ||
import org.apache.spark.sql.errors.QueryCompilationErrors | ||
import org.apache.spark.sql.internal.SQLConf | ||
|
@@ -249,6 +250,73 @@ object SubExprUtils extends PredicateHelper { | |
} | ||
} | ||
} | ||
|
||
/** | ||
* Returns the inner query attributes that are guaranteed to have a single value for each | ||
* outer row. Therefore, a scalar subquery is allowed to group-by on these attributes. | ||
* We can derive these from correlated equality predicates, though we need to take care about | ||
* propagating this through operators like OUTER JOIN or UNION. | ||
* | ||
* Positive examples: x = outer(a) AND y = outer(b) | ||
* Negative examples: | ||
* - x <= outer(a) | ||
* - x + y = outer(a) | ||
* - x = outer(a) OR y = outer(b) | ||
* - y = outer(b) + 1 (this and similar expressions could be supported, but very carefully) | ||
* - An equality under the right side of a LEFT OUTER JOIN, e.g. | ||
* select *, (select count(*) from y left join | ||
* (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x; | ||
* - An equality under UNION e.g. | ||
* select *, (select count(*) from | ||
* (select * from y where y1 = x1 union all select * from y) group by y1) from x; | ||
*/ | ||
def getCorrelatedEquivalentInnerColumns(plan: LogicalPlan): AttributeSet = { | ||
plan match { | ||
case Filter(cond, child) => | ||
val correlated = AttributeSet(splitConjunctivePredicates(cond) | ||
.filter(containsOuter) // TODO: can remove this line to allow e.g. where x = 1 group by x | ||
.filter(DecorrelateInnerQuery.canPullUpOverAgg) | ||
.flatMap(_.references)) | ||
correlated ++ getCorrelatedEquivalentInnerColumns(child) | ||
|
||
case Join(left, right, joinType, _, _) => | ||
joinType match { | ||
case _: InnerLike => | ||
AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child))) | ||
case LeftOuter => getCorrelatedEquivalentInnerColumns(left) | ||
case RightOuter => getCorrelatedEquivalentInnerColumns(right) | ||
case FullOuter => AttributeSet.empty | ||
case LeftSemi => getCorrelatedEquivalentInnerColumns(left) | ||
case LeftAnti => getCorrelatedEquivalentInnerColumns(left) | ||
case _ => AttributeSet.empty | ||
} | ||
|
||
case _: Union => AttributeSet.empty | ||
case Except(left, right, _) => getCorrelatedEquivalentInnerColumns(left) | ||
|
||
case | ||
_: Aggregate | | ||
_: Distinct | | ||
_: Intersect | | ||
_: GlobalLimit | | ||
_: LocalLimit | | ||
_: Offset | | ||
_: Project | | ||
_: Repartition | | ||
_: RepartitionByExpression | | ||
_: RebalancePartitions | | ||
_: Sample | | ||
_: Sort | | ||
_: Window | | ||
_: Tail | | ||
_: WithCTE | | ||
_: Range | | ||
_: SubqueryAlias => | ||
AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child))) | ||
|
||
case _ => AttributeSet.empty | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The list of operators handled here is by no means comprehensive and ensuring it covers enough is tricky. I used the list in LogicalPlanVisitor as a starting point, but in my testing I discovered that e.g. SubqueryAlias also needs to be handled to cover cases with FROM subqueries inside the scalar subquery. Suggestions on other important operators to handle or other potential approaches welcome. (In the long run I think we need to replace this entire check with a runtime check as described in https://issues.apache.org/jira/browse/SPARK-48501, but that's highly nontrivial) |
||
} | ||
} | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
-- Automatically generated by SQLQueryTestSuite | ||
-- !query | ||
create temp view x (x1, x2) as values (1, 1), (2, 2) | ||
-- !query analysis | ||
CreateViewCommand `x`, [(x1,None), (x2,None)], values (1, 1), (2, 2), false, false, LocalTempView, UNSUPPORTED, true | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
create temp view y (y1, y2) as values (2, 0), (3, -1) | ||
-- !query analysis | ||
CreateViewCommand `y`, [(y1,None), (y2,None)], values (2, 0), (3, -1), false, false, LocalTempView, UNSUPPORTED, true | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
create temp view z (z1, z2) as values (1, 0), (1, 1) | ||
-- !query analysis | ||
CreateViewCommand `z`, [(z1,None), (z2,None)], values (1, 0), (1, 1), false, false, LocalTempView, UNSUPPORTED, true | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
select * from x where (select count(*) from y where y1 = x1 group by y1) = 1 | ||
-- !query analysis | ||
Project [x1#x, x2#x] | ||
+- Filter (scalar-subquery#x [x1#x] = cast(1 as bigint)) | ||
: +- Aggregate [y1#x], [count(1) AS count(1)#xL] | ||
: +- Filter (y1#x = outer(x1#x)) | ||
: +- SubqueryAlias y | ||
: +- View (`y`, [y1#x, y2#x]) | ||
: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] | ||
: +- LocalRelation [col1#x, col2#x] | ||
+- SubqueryAlias x | ||
+- View (`x`, [x1#x, x2#x]) | ||
+- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
select * from x where (select count(*) from y where y1 = x1 group by x1) = 1 | ||
-- !query analysis | ||
org.apache.spark.sql.catalyst.ExtendedAnalysisException | ||
{ | ||
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_REFERENCE", | ||
"sqlState" : "0A000", | ||
"messageParameters" : { | ||
"sqlExprs" : "\"x1\"" | ||
}, | ||
"queryContext" : [ { | ||
"objectType" : "", | ||
"objectName" : "", | ||
"startIndex" : 61, | ||
"stopIndex" : 71, | ||
"fragment" : "group by x1" | ||
} ] | ||
} | ||
|
||
|
||
-- !query | ||
select * from x where (select count(*) from y where y1 > x1 group by x1) = 1 | ||
-- !query analysis | ||
org.apache.spark.sql.catalyst.ExtendedAnalysisException | ||
{ | ||
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_REFERENCE", | ||
"sqlState" : "0A000", | ||
"messageParameters" : { | ||
"sqlExprs" : "\"x1\"" | ||
}, | ||
"queryContext" : [ { | ||
"objectType" : "", | ||
"objectName" : "", | ||
"startIndex" : 61, | ||
"stopIndex" : 71, | ||
"fragment" : "group by x1" | ||
} ] | ||
} | ||
|
||
|
||
-- !query | ||
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 | ||
-- !query analysis | ||
org.apache.spark.sql.catalyst.ExtendedAnalysisException | ||
{ | ||
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", | ||
"sqlState" : "0A000", | ||
"messageParameters" : { | ||
"value" : "y1" | ||
}, | ||
"queryContext" : [ { | ||
"objectType" : "", | ||
"objectName" : "", | ||
"startIndex" : 23, | ||
"stopIndex" : 72, | ||
"fragment" : "(select count(*) from y where y1 > x1 group by y1)" | ||
} ] | ||
} | ||
|
||
|
||
-- !query | ||
select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x | ||
-- !query analysis | ||
org.apache.spark.sql.catalyst.ExtendedAnalysisException | ||
{ | ||
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", | ||
"sqlState" : "0A000", | ||
"messageParameters" : { | ||
"value" : "y1" | ||
}, | ||
"queryContext" : [ { | ||
"objectType" : "", | ||
"objectName" : "", | ||
"startIndex" : 11, | ||
"stopIndex" : 65, | ||
"fragment" : "(select count(*) from y where y1 + y2 = x1 group by y1)" | ||
} ] | ||
} | ||
|
||
|
||
-- !query | ||
select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x | ||
-- !query analysis | ||
org.apache.spark.sql.catalyst.ExtendedAnalysisException | ||
{ | ||
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", | ||
"sqlState" : "0A000", | ||
"messageParameters" : { | ||
"value" : "y2" | ||
}, | ||
"queryContext" : [ { | ||
"objectType" : "", | ||
"objectName" : "", | ||
"startIndex" : 11, | ||
"stopIndex" : 71, | ||
"fragment" : "(select count(*) from y where x1 = y1 and y2 = 1 group by y2)" | ||
} ] | ||
} | ||
|
||
|
||
-- !query | ||
select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x | ||
-- !query analysis | ||
org.apache.spark.sql.catalyst.ExtendedAnalysisException | ||
{ | ||
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", | ||
"sqlState" : "0A000", | ||
"messageParameters" : { | ||
"value" : "y1" | ||
}, | ||
"queryContext" : [ { | ||
"objectType" : "", | ||
"objectName" : "", | ||
"startIndex" : 11, | ||
"stopIndex" : 106, | ||
"fragment" : "(select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1)" | ||
} ] | ||
} | ||
|
||
|
||
-- !query | ||
select *, (select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x | ||
-- !query analysis | ||
org.apache.spark.sql.catalyst.ExtendedAnalysisException | ||
{ | ||
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", | ||
"sqlState" : "0A000", | ||
"messageParameters" : { | ||
"value" : "z1" | ||
}, | ||
"queryContext" : [ { | ||
"objectType" : "", | ||
"objectName" : "", | ||
"startIndex" : 11, | ||
"stopIndex" : 103, | ||
"fragment" : "(select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1)" | ||
} ] | ||
} | ||
|
||
|
||
-- !query | ||
set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = true | ||
-- !query analysis | ||
SetCommand (spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate,Some(true)) | ||
|
||
|
||
-- !query | ||
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 | ||
-- !query analysis | ||
Project [x1#x, x2#x] | ||
+- Filter (scalar-subquery#x [x1#x] = cast(1 as bigint)) | ||
: +- Aggregate [y1#x], [count(1) AS count(1)#xL] | ||
: +- Filter (y1#x > outer(x1#x)) | ||
: +- SubqueryAlias y | ||
: +- View (`y`, [y1#x, y2#x]) | ||
: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] | ||
: +- LocalRelation [col1#x, col2#x] | ||
+- SubqueryAlias x | ||
+- View (`x`, [x1#x, x2#x]) | ||
+- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
reset spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate | ||
-- !query analysis | ||
ResetCommand spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
-- Tests for scalar subquery with a group-by. Only a group-by that guarantees a single row result is allowed. See SPARK-48503 | ||
|
||
--ONLY_IF spark | ||
|
||
create temp view x (x1, x2) as values (1, 1), (2, 2); | ||
create temp view y (y1, y2) as values (2, 0), (3, -1); | ||
create temp view z (z1, z2) as values (1, 0), (1, 1); | ||
|
||
-- Legal queries | ||
select * from x where (select count(*) from y where y1 = x1 group by y1) = 1; | ||
select * from x where (select count(*) from y where y1 = x1 group by x1) = 1; | ||
select * from x where (select count(*) from y where y1 > x1 group by x1) = 1; | ||
|
||
-- Illegal queries | ||
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; | ||
select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x; | ||
|
||
-- Equality with literal - disallowed currently but can actually be allowed | ||
select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x; | ||
|
||
-- Certain other operators like OUTER JOIN or UNION between the correlating filter and the group-by also can cause the scalar subquery to return multiple values and hence make the query illegal. | ||
select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x; | ||
select *, (select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x; -- The correlation below the join is unsupported in Spark anyway, but when we do support it this query should still be disallowed. | ||
|
||
-- Test legacy behavior conf | ||
set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = true; | ||
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; | ||
reset spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I intend to enable that in a separate PR, to reduce risk here.