diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index be1b43c875598..573e7f3a6a384 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5365,11 +5365,6 @@
           "Parameter markers are not allowed in <statement>."
         ]
       },
-      "PARTITION_BY_MAP" : {
-        "message" : [
-          "Cannot use MAP producing expressions to partition a DataFrame, but the type of expression <expr> is <dataType>."
-        ]
-      },
       "PARTITION_BY_VARIANT" : {
         "message" : [
           "Cannot use VARIANT producing expressions to partition a DataFrame, but the type of expression <expr> is <dataType>."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 06e8e26c0cebb..46ca8e793218b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -103,13 +103,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
       case _ => None
     }
 
-  protected def mapExprInPartitionExpression(plan: LogicalPlan): Option[Expression] =
-    plan match {
-      case r: RepartitionByExpression =>
-        r.partitionExpressions.find(e => hasMapType(e.dataType))
-      case _ => None
-    }
-
   private def checkLimitLikeClause(name: String, limitExpr: Expression): Unit = {
     limitExpr match {
       case e if !e.foldable => limitExpr.failAnalysis(
@@ -894,14 +887,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
               "expr" -> toSQLExpr(variantExpr),
               "dataType" -> toSQLType(variantExpr.dataType)))
 
-        case o if mapExprInPartitionExpression(o).isDefined =>
-          val mapExpr = mapExprInPartitionExpression(o).get
-          o.failAnalysis(
-            errorClass = "UNSUPPORTED_FEATURE.PARTITION_BY_MAP",
-            messageParameters = Map(
-              "expr" -> toSQLExpr(mapExpr),
-              "dataType" -> toSQLType(mapExpr.dataType)))
-
         case o if o.expressions.exists(!_.deterministic) &&
           !operatorAllowsNonDeterministicExpressions(o) &&
          !o.isInstanceOf[Project] &&
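
For context: the analyzer check removed above is what used to reject map partition keys. A minimal sketch of the behavior change (hypothetical snippet, not part of the patch; assumes a local SparkSession named `spark`, with `Try` standing in for a test harness):

    import scala.util.Try
    import org.apache.spark.sql.functions.{col, lit, map}

    // On builds predating this patch, analysis rejects map partition keys with
    // the (now removed) UNSUPPORTED_FEATURE.PARTITION_BY_MAP error condition;
    // with the patch, the same call succeeds and map_sort is inserted instead.
    val attempt = Try {
      spark.range(1).select(map(lit(1), lit(1)).as("m")).repartition(col("m")).collect()
    }
    println(attempt) // pre-patch: Failure(AnalysisException ... PARTITION_BY_MAP ...)
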
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala
similarity index 68%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala
index b6ced6c49a36f..9e153ba6722a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala
@@ -20,32 +20,27 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, RepartitionByExpression}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern
+import org.apache.spark.sql.catalyst.trees.TreePattern.REPARTITION_OPERATION
 import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
 import org.apache.spark.util.ArrayImplicits.SparkArrayOps
 
 /**
- * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
- * in the correct order before grouping:
+ * Adds [[MapSort]] to [[Aggregate]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before grouping:
  *
- * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
+ *   SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
  *   SELECT _groupingmapsort as map_column, COUNT(*) FROM (
  *     SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
  *   ) GROUP BY _groupingmapsort
  */
 object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
-  private def shouldAddMapSort(expr: Expression): Boolean = {
-    expr.dataType.existsRecursively(_.isInstanceOf[MapType])
-  }
+  import InsertMapSortExpression._
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    if (!plan.containsPattern(TreePattern.AGGREGATE)) {
-      return plan
-    }
     val shouldRewrite = plan.exists {
-      case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
+      case agg: Aggregate if agg.groupingExpressions.exists(mapTypeExistsRecursively) => true
       case _ => false
     }
     if (!shouldRewrite) {
@@ -53,8 +48,7 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
     }
 
     plan transformUpWithNewOutput {
-      case agg @ Aggregate(groupingExprs, aggregateExpressions, child, _)
-          if agg.groupingExpressions.exists(shouldAddMapSort) =>
+      case agg @ Aggregate(groupingExprs, aggregateExpressions, child, hint) =>
         val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
         val newGroupingKeys = groupingExprs.map { expr =>
           val inserted = insertMapSortRecursively(expr)
@@ -77,15 +71,56 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
         }.asInstanceOf[NamedExpression]
       }
       val newChild = Project(child.output ++ exprToMapSort.values, child)
-      val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
+      val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild, hint)
       newAgg -> agg.output.zip(newAgg.output)
     }
   }
+}
+
+/**
+ * Adds [[MapSort]] to [[RepartitionByExpression]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before repartitioning:
+ *
+ *   SELECT * FROM TABLE DISTRIBUTE BY map_column =>
+ *   SELECT * FROM TABLE DISTRIBUTE BY map_sort(map_column)
+ */
+object InsertMapSortInRepartitionExpressions extends Rule[LogicalPlan] {
+  import InsertMapSortExpression._
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.transformUpWithPruning(_.containsPattern(REPARTITION_OPERATION)) {
+      case rep @
+          RepartitionByExpression(partitionExprs, child, optNumPartitions, optAdvisoryPartitionSize)
+          if rep.partitionExpressions.exists(mapTypeExistsRecursively) =>
+        val exprToMapSort = new mutable.HashMap[Expression, Expression]
+        val newPartitionExprs = partitionExprs.map { expr =>
+          val inserted = insertMapSortRecursively(expr)
+          if (expr.ne(inserted)) {
+            exprToMapSort.getOrElseUpdate(expr.canonicalized, inserted)
+          } else {
+            expr
+          }
+        }
+        RepartitionByExpression(
+          newPartitionExprs, child, optNumPartitions, optAdvisoryPartitionSize
+        )
+    }
+  }
+}
+
+private[optimizer] object InsertMapSortExpression {
 
   /**
-   * Inserts MapSort recursively taking into account when it is nested inside a struct or array.
+   * Returns true if the expression contains a [[MapType]] in its DataType tree.
    */
-  private def insertMapSortRecursively(e: Expression): Expression = {
+  def mapTypeExistsRecursively(expr: Expression): Boolean = {
+    expr.dataType.existsRecursively(_.isInstanceOf[MapType])
+  }
+
+  /**
+   * Inserts [[MapSort]] recursively taking into account when it is nested inside a struct or array.
+   */
+  def insertMapSortRecursively(e: Expression): Expression = {
     e.dataType match {
       case m: MapType =>
         // Check if value type of MapType contains MapType (possibly nested)
@@ -122,5 +157,4 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
       case _ => e
     }
   }
-
 }
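
To see the new rule fire, a quick sketch (hypothetical, not part of the patch; assumes a SparkSession `spark`, and the exact expression name printed in the plan may vary by version):

    import spark.implicits._
    import org.apache.spark.sql.functions.col

    // Two rows whose maps are equal but were built in different key orders.
    val df = Seq(Map(2 -> "b", 1 -> "a"), Map(1 -> "a", 2 -> "b")).toDF("m")
    // The optimized plan should show the partition key wrapped in map_sort,
    // roughly: RepartitionByExpression [map_sort(m#...)], 4
    df.repartition(4, col("m")).explain(true)
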
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b141d2be04c32..faea022ebbba3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -321,6 +321,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       // so the grouping keys can only be attribute and literal which makes
       // `InsertMapSortInGroupingExpressions` easy to insert `MapSort`.
       InsertMapSortInGroupingExpressions,
+      InsertMapSortInRepartitionExpressions,
       ComputeCurrentTime,
       ReplaceCurrentLike(catalogManager),
       SpecialDatetimeValues,
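
The placement in this batch matters: the rewrite has to run before physical planning so that hash partitioning sees canonicalized keys. A rough end-to-end check (hypothetical, not part of the patch; assumes a SparkSession `spark`):

    import org.apache.spark.sql.functions.{col, lit, map, spark_partition_id, when}

    // Semantically equal maps built in opposite key orders.
    val df = spark.range(2).select(
      when(col("id") === 0, map(lit(1), lit(1), lit(2), lit(2)))
        .otherwise(map(lit(2), lit(2), lit(1), lit(1))).as("m"))
    // With InsertMapSortInRepartitionExpressions registered above, both rows
    // should land in the same partition.
    val used = df.repartition(8, col("m")).select(spark_partition_id()).distinct().count()
    assert(used == 1)
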
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRepartitionSuite.scala
new file mode 100644
index 0000000000000..5de0259598c2a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRepartitionSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions.{col, lit, spark_partition_id, typedLit, when}
+import org.apache.spark.sql.test.SharedSparkSession
+
+class DataFrameRepartitionSuite extends QueryTest with SharedSparkSession {
+
+  test("repartition by MapType") {
+    Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
+      val df = spark.range(20)
+        .withColumn("c1",
+          when(col("id") % 3 === 1, typedLit(Map(1 -> 2)))
+            .when(col("id") % 3 === 2, typedLit(Map(1 -> 1, 2 -> 2)))
+            .otherwise(typedLit(Map(2 -> 2, 1 -> 1))).cast(s"map<$dt, $dt>"))
+        .withColumn("c2", typedLit(Map(1 -> null)).cast(s"map<$dt, $dt>"))
+        .withColumn("c3", lit(null).cast(s"map<$dt, $dt>"))
+
+      assertRepartitionNumber(df.repartition(4, col("c1")), 2)
+      assertRepartitionNumber(df.repartition(4, col("c2")), 1)
+      assertRepartitionNumber(df.repartition(4, col("c3")), 1)
+      assertRepartitionNumber(df.repartition(4, col("c1"), col("c2")), 2)
+      assertRepartitionNumber(df.repartition(4, col("c1"), col("c3")), 2)
+      assertRepartitionNumber(df.repartition(4, col("c1"), col("c2"), col("c3")), 2)
+      assertRepartitionNumber(df.repartition(4, col("c2"), col("c3")), 2)
+    }
+  }
+
+  def assertRepartitionNumber(df: => DataFrame, max: Int): Unit = {
+    val dfGrouped = df.groupBy(spark_partition_id()).count()
+    // The resulting number of partitions can be lower than or equal to max,
+    // but never more than that.
+    assert(dfGrouped.count() <= max, dfGrouped.queryExecution.simpleString)
+  }
+}
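
The expected upper bounds in the test follow from the number of distinct keys after sorting: c1 takes two distinct values once map_sort canonicalizes Map(2 -> 2, 1 -> 1) to Map(1 -> 1, 2 -> 2), while c2 and c3 are constant or null. A minimal standalone version of the single-key case (hypothetical, not part of the patch; assumes a local SparkSession `spark`):

    import org.apache.spark.sql.functions.{col, spark_partition_id, typedLit}

    // A single distinct map key can occupy at most one partition.
    val df = spark.range(20).withColumn("c1", typedLit(Map(1 -> 1, 2 -> 2)))
    val used = df.repartition(4, col("c1")).groupBy(spark_partition_id()).count().count()
    assert(used <= 1)
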