diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2bc6785aa40c3..3a18b0c4d2eae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -905,26 +905,31 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns // are not part of the correlated columns. - // Note: groupByCols does not contain outer refs - grouping by an outer ref is always ok - val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) - // Collect the inner query attributes that are guaranteed to have a single value for each - // outer row. See comment on getCorrelatedEquivalentInnerColumns. - val correlatedEquivalentCols = getCorrelatedEquivalentInnerColumns(query) - val nonEquivalentGroupByCols = groupByCols -- correlatedEquivalentCols + // Collect the inner query expressions that are guaranteed to have a single value for each + // outer row. See comment on getCorrelatedEquivalentInnerExpressions. + val correlatedEquivalentExprs = getCorrelatedEquivalentInnerExpressions(query) + // Grouping expressions, except outer refs and constant expressions - grouping by an + // outer ref or a constant is always ok + val groupByExprs = + ExpressionSet(agg.groupingExpressions.filter(x => !x.isInstanceOf[OuterReference] && + x.references.nonEmpty)) + val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs val invalidCols = if (!SQLConf.get.getConf( SQLConf.LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE)) { - nonEquivalentGroupByCols + nonEquivalentGroupByExprs } else { // Legacy incorrect logic for checking for invalid group-by columns (see SPARK-48503). // Allows any inner attribute that appears in a correlated predicate, even if it is a // non-equality predicate or under an operator that can change the values of the attribute // (see comments on getCorrelatedEquivalentInnerColumns for examples). + // Note: groupByCols does not contain outer refs - grouping by an outer ref is always ok + val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) .filterNot(conditions.flatMap(_.references).contains) val correlatedCols = AttributeSet(subqueryColumns) val invalidColsLegacy = groupByCols -- correlatedCols - if (!nonEquivalentGroupByCols.isEmpty && invalidColsLegacy.isEmpty) { + if (!nonEquivalentGroupByExprs.isEmpty && invalidColsLegacy.isEmpty) { logWarning(log"Using legacy behavior for " + log"${MDC(LogKeys.CONFIG, SQLConf .LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE.key)}. " + @@ -936,10 +941,16 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB } if (invalidCols.nonEmpty) { + val names = invalidCols.map { el => + el match { + case attr: Attribute => attr.name + case expr: Expression => expr.toString + } + } expr.failAnalysis( errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "NON_CORRELATED_COLUMNS_IN_GROUP_BY", - messageParameters = Map("value" -> invalidCols.map(_.name).mkString(","))) + messageParameters = Map("value" -> names.mkString(","))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 75ca4930cf8c1..174d32c73fc01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.optimizer.DecorrelateInnerQuery import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -252,8 +251,42 @@ object SubExprUtils extends PredicateHelper { } /** - * Returns the inner query attributes that are guaranteed to have a single value for each - * outer row. Therefore, a scalar subquery is allowed to group-by on these attributes. + * Matches an equality 'expr = func(outer)', where 'func(outer)' depends on outer rows or + * is a constant. + * A scalar subquery is allowed to group-by on 'expr', as they are guaranteed to have exactly + * one value for every outer row. + * Positive examples: + * - x + 1 = outer(a) + * - cast(x as date) = outer(b) + * - y + z = 100 + * - y / 10 = outer(b) + outer(c) + * In all of these examples, the left side of the equality will be returned. + * + * Negative examples: + * - x < outer(b) + * - x = y + * In all of these examples, None will be returned. + * @param expr + * @return + */ + private def getEquivalentToOuter(expr: Expression): Option[Expression] = { + val allowConstants = + SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_ALLOW_GROUP_BY_COLUMN_EQUAL_TO_CONSTANT) + + expr match { + case EqualTo(left, x) + if ((allowConstants || containsOuter(x)) && + !x.exists(_.isInstanceOf[Attribute])) => Some(left) + case EqualTo(x, right) + if ((allowConstants || containsOuter(x)) && + !x.exists(_.isInstanceOf[Attribute])) => Some(right) + case _ => None + } + } + + /** + * Returns the inner query expressions that are guaranteed to have a single value for each + * outer row. Therefore, a scalar subquery is allowed to group-by on these expressions. * We can derive these from correlated equality predicates, though we need to take care about * propagating this through operators like OUTER JOIN or UNION. * @@ -261,6 +294,7 @@ object SubExprUtils extends PredicateHelper { * - x = outer(a) AND y = outer(b) * - x = 1 * - x = outer(a) + 1 + * - cast(x as date) = current_date() + outer(b) * * Negative examples: * - x <= outer(a) @@ -274,31 +308,31 @@ object SubExprUtils extends PredicateHelper { * select *, (select count(*) from * (select * from y where y1 = x1 union all select * from y) group by y1) from x; */ - def getCorrelatedEquivalentInnerColumns(plan: LogicalPlan): AttributeSet = { + def getCorrelatedEquivalentInnerExpressions(plan: LogicalPlan): ExpressionSet = { plan match { case Filter(cond, child) => - val correlated = AttributeSet(splitConjunctivePredicates(cond) + val equivalentExprs = ExpressionSet(splitConjunctivePredicates(cond) .filter( SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_ALLOW_GROUP_BY_COLUMN_EQUAL_TO_CONSTANT) || containsOuter(_)) - .filter(DecorrelateInnerQuery.canPullUpOverAgg) - .flatMap(_.references)) - correlated ++ getCorrelatedEquivalentInnerColumns(child) + .flatMap(getEquivalentToOuter)) + equivalentExprs ++ getCorrelatedEquivalentInnerExpressions(child) case Join(left, right, joinType, _, _) => joinType match { case _: InnerLike => - AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child))) - case LeftOuter => getCorrelatedEquivalentInnerColumns(left) - case RightOuter => getCorrelatedEquivalentInnerColumns(right) - case FullOuter => AttributeSet.empty - case LeftSemi => getCorrelatedEquivalentInnerColumns(left) - case LeftAnti => getCorrelatedEquivalentInnerColumns(left) - case _ => AttributeSet.empty + ExpressionSet(plan.children.flatMap( + child => getCorrelatedEquivalentInnerExpressions(child))) + case LeftOuter => getCorrelatedEquivalentInnerExpressions(left) + case RightOuter => getCorrelatedEquivalentInnerExpressions(right) + case FullOuter => ExpressionSet().empty + case LeftSemi => getCorrelatedEquivalentInnerExpressions(left) + case LeftAnti => getCorrelatedEquivalentInnerExpressions(left) + case _ => ExpressionSet().empty } - case _: Union => AttributeSet.empty - case Except(left, right, _) => getCorrelatedEquivalentInnerColumns(left) + case _: Union => ExpressionSet().empty + case Except(left, _, _) => getCorrelatedEquivalentInnerExpressions(left) case _: Aggregate | @@ -318,9 +352,10 @@ object SubExprUtils extends PredicateHelper { _: WithCTE | _: Range | _: SubqueryAlias => - AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child))) + ExpressionSet(plan.children.flatMap(child => + getCorrelatedEquivalentInnerExpressions(child))) - case _ => AttributeSet.empty + case _ => ExpressionSet().empty } } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index 671557aa39566..bea91e09b0053 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -109,6 +109,39 @@ Project [x1#x, x2#x, scalar-subquery#x [x1#x && x1#x] AS scalarsubquery(x1, x1)# +- LocalRelation [col1#x, col2#x] +-- !query +select *, (select count(*) from y where x1 = y1 and cast(y2 as double) = x1 + 1 + group by cast(y2 as double)) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x && x1#x] AS scalarsubquery(x1, x1)#xL] +: +- Aggregate [cast(y2#x as double)], [count(1) AS count(1)#xL] +: +- Filter ((outer(x1#x) = y1#x) AND (cast(y2#x as double) = cast((outer(x1#x) + 1) as double))) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from y where y2 + 1 = x1 + x2 group by y2 + 1) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x && x2#x] AS scalarsubquery(x1, x2)#xL] +: +- Aggregate [(y2#x + 1)], [count(1) AS count(1)#xL] +: +- Filter ((y2#x + 1) = (outer(x1#x) + outer(x2#x))) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query analysis @@ -149,6 +182,26 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", + "sqlState" : "0A000", + "messageParameters" : { + "value" : "y2" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 11, + "stopIndex" : 81, + "fragment" : "(select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2)" + } ] +} + + -- !query select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql index 6787fac75b39a..db7cdc97614cb 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql @@ -15,10 +15,17 @@ select * from x where (select count(*) from y where y1 > x1 group by x1) = 1; select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x; -- Group-by column equal to expression with constants and outer refs - legal select *, (select count(*) from y where x1 = y1 and y2 = x1 + 1 group by y2) from x; +-- Group-by expression is the same as the one we filter on - legal +select *, (select count(*) from y where x1 = y1 and cast(y2 as double) = x1 + 1 + group by cast(y2 as double)) from x; +-- Group-by expression equal to an expression that depends on 2 outer refs -- legal +select *, (select count(*) from y where y2 + 1 = x1 + x2 group by y2 + 1) from x; + -- Illegal queries select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x; +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x; -- Certain other operators like OUTER JOIN or UNION between the correlating filter and the group-by also can cause the scalar subquery to return multiple values and hence make the query illegal. select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index 85ebd91c28c9c..41cba1f43745f 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -93,6 +93,25 @@ struct 2 2 NULL +-- !query +select *, (select count(*) from y where x1 = y1 and cast(y2 as double) = x1 + 1 + group by cast(y2 as double)) from x +-- !query schema +struct +-- !query output +1 1 NULL +2 2 NULL + + +-- !query +select *, (select count(*) from y where y2 + 1 = x1 + x2 group by y2 + 1) from x +-- !query schema +struct +-- !query output +1 1 NULL +2 2 NULL + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query schema @@ -137,6 +156,28 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", + "sqlState" : "0A000", + "messageParameters" : { + "value" : "y2" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 11, + "stopIndex" : 81, + "fragment" : "(select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2)" + } ] +} + + -- !query select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x -- !query schema