Commit

Permalink
Apply @cloud-fan suggestions
ostronaut committed Jan 6, 2025
1 parent 2d76099 commit d7efa2b
Showing 5 changed files with 104 additions and 38 deletions.
common/utils/src/main/resources/error/error-conditions.json (0 additions & 5 deletions)
@@ -5365,11 +5365,6 @@
"Parameter markers are not allowed in <statement>."
]
},
"PARTITION_BY_MAP" : {
"message" : [
"Cannot use MAP producing expressions to partition a DataFrame, but the type of expression <expr> is <dataType>."
]
},
"PARTITION_BY_VARIANT" : {
"message" : [
"Cannot use VARIANT producing expressions to partition a DataFrame, but the type of expression <expr> is <dataType>."
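With PARTITION_BY_MAP removed from the error conditions (and the matching analyzer check removed below), repartitioning a DataFrame by a map-typed column no longer fails analysis; the optimizer normalizes the key instead. A minimal sketch of the behavior change, assuming a local `spark` session (the DataFrame here is an invented example, not part of the commit):

  import org.apache.spark.sql.functions.col

  // Before this commit: AnalysisException with UNSUPPORTED_FEATURE.PARTITION_BY_MAP.
  // After: the key is rewritten to map_sort(m) by the new optimizer rule below.
  val df = spark.range(10).selectExpr("map(id % 3, id) AS m")
  df.repartition(4, col("m")).count()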
@@ -103,13 +103,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsBase
case _ => None
}

- protected def mapExprInPartitionExpression(plan: LogicalPlan): Option[Expression] =
- plan match {
- case r: RepartitionByExpression =>
- r.partitionExpressions.find(e => hasMapType(e.dataType))
- case _ => None
- }
-
private def checkLimitLikeClause(name: String, limitExpr: Expression): Unit = {
limitExpr match {
case e if !e.foldable => limitExpr.failAnalysis(
@@ -894,14 +887,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsBase
"expr" -> toSQLExpr(variantExpr),
"dataType" -> toSQLType(variantExpr.dataType)))

- case o if mapExprInPartitionExpression(o).isDefined =>
- val mapExpr = mapExprInPartitionExpression(o).get
- o.failAnalysis(
- errorClass = "UNSUPPORTED_FEATURE.PARTITION_BY_MAP",
- messageParameters = Map(
- "expr" -> toSQLExpr(mapExpr),
- "dataType" -> toSQLType(mapExpr.dataType)))
-
case o if o.expressions.exists(!_.deterministic) &&
!operatorAllowsNonDeterministicExpressions(o) &&
!o.isInstanceOf[Project] &&
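The hasMapType helper called by the removed guard sits outside this diff's context. Presumably it is a recursive type test along the lines of the sketch below (note that DataType.existsRecursively is package-private, so this only compiles inside Spark's own org.apache.spark packages); the same predicate now lives in the optimizer as mapTypeExistsRecursively, in the next file:

  import org.apache.spark.sql.types.{DataType, MapType}

  // True if any type nested anywhere inside `dt` is a MapType.
  def hasMapType(dt: DataType): Boolean =
    dt.existsRecursively(_.isInstanceOf[MapType])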
@@ -20,41 +20,35 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
- import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, RepartitionByExpression}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern
+ import org.apache.spark.sql.catalyst.trees.TreePattern.REPARTITION_OPERATION
import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
import org.apache.spark.util.ArrayImplicits.SparkArrayOps

/**
- * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
- * in the correct order before grouping:
+ * Adds [[MapSort]] to [[Aggregate]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before grouping:
*
* SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
* SELECT _groupingmapsort as map_column, COUNT(*) FROM (
* SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
* ) GROUP BY _groupingmapsort
*/
object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
- private def shouldAddMapSort(expr: Expression): Boolean = {
- expr.dataType.existsRecursively(_.isInstanceOf[MapType])
- }
+ import InsertMapSortExpression._

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!plan.containsPattern(TreePattern.AGGREGATE)) {
return plan
}
val shouldRewrite = plan.exists {
- case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
+ case agg: Aggregate if agg.groupingExpressions.exists(mapTypeExistsRecursively) => true
case _ => false
}
if (!shouldRewrite) {
return plan
}

plan transformUpWithNewOutput {
- case agg @ Aggregate(groupingExprs, aggregateExpressions, child, _)
- if agg.groupingExpressions.exists(shouldAddMapSort) =>
+ case agg @ Aggregate(groupingExprs, aggregateExpressions, child, hint) =>
val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
val newGroupingKeys = groupingExprs.map { expr =>
val inserted = insertMapSortRecursively(expr)
@@ -77,15 +71,56 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
}.asInstanceOf[NamedExpression]
}
val newChild = Project(child.output ++ exprToMapSort.values, child)
- val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
+ val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild, hint)
newAgg -> agg.output.zip(newAgg.output)
}
}
}
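The rewrite this rule performs is visible in the optimized plan. A quick sketch, assuming a `spark` session on a build that contains this rule:

  // Group by a map column; the Aggregate should key on a projected
  // map_sort(m) column instead of the raw map.
  val grouped = spark.range(10).selectExpr("map(id % 2, id) AS m").groupBy("m").count()
  grouped.queryExecution.optimizedPlan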

+ /**
+ * Adds [[MapSort]] to [[RepartitionByExpression]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before repartitioning:
+ *
+ * SELECT * FROM TABLE DISTRIBUTE BY map_column =>
+ * SELECT * FROM TABLE DISTRIBUTE BY map_sort(map_column)
+ */
+ object InsertMapSortInRepartitionExpressions extends Rule[LogicalPlan] {
+ import InsertMapSortExpression._
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ plan.transformUpWithPruning(_.containsPattern(REPARTITION_OPERATION)) {
+ case rep @
+ RepartitionByExpression(partitionExprs, child, optNumPartitions, optAdvisoryPartitionSize)
+ if rep.partitionExpressions.exists(mapTypeExistsRecursively) =>
+ val exprToMapSort = new mutable.HashMap[Expression, Expression]
+ val newPartitionExprs = partitionExprs.map { expr =>
+ val inserted = insertMapSortRecursively(expr)
+ if (expr.ne(inserted)) {
+ exprToMapSort.getOrElseUpdate(expr.canonicalized, inserted)
+ } else {
+ expr
+ }
+ }
+ RepartitionByExpression(
+ newPartitionExprs, child, optNumPartitions, optAdvisoryPartitionSize
+ )
+ }
+ }
+ }
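The DISTRIBUTE BY form from the scaladoc can be checked the same way. Another hedged sketch (the temp view name `t` is an assumption):

  spark.range(10).selectExpr("map(id % 2, id) AS m").createOrReplaceTempView("t")
  val q = spark.sql("SELECT * FROM t DISTRIBUTE BY m")
  // RepartitionByExpression should now carry map_sort(m) as its partition key.
  q.queryExecution.optimizedPlan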

+ private[optimizer] object InsertMapSortExpression {
+
/**
- * Inserts MapSort recursively taking into account when it is nested inside a struct or array.
+ * Returns true if the expression contains a [[MapType]] in DataType tree.
*/
- private def insertMapSortRecursively(e: Expression): Expression = {
+ def mapTypeExistsRecursively(expr: Expression): Boolean = {
+ expr.dataType.existsRecursively(_.isInstanceOf[MapType])
+ }
+
+ /**
+ * Inserts [[MapSort]] recursively taking into account when it is nested inside a struct or array.
+ */
+ def insertMapSortRecursively(e: Expression): Expression = {
e.dataType match {
case m: MapType =>
// Check if value type of MapType contains MapType (possibly nested)
@@ -122,5 +157,4 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
case _ => e
}
}
-
}
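Because insertMapSortRecursively descends through structs and arrays (most of its body is collapsed out of this diff), nested maps are normalized as well. A sketch of the nested case, assuming a `spark` session:

  import org.apache.spark.sql.functions.col

  // Partitioning by array<map<...>> should show transform(am, x -> map_sort(x))
  // (an ArrayTransform over MapSort) in the optimized plan.
  val nested = spark.range(10).selectExpr("array(map(id % 2, id)) AS am")
  nested.repartition(4, col("am")).queryExecution.optimizedPlan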
@@ -321,6 +321,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
// so the grouping keys can only be attribute and literal which makes
// `InsertMapSortInGroupingExpressions` easy to insert `MapSort`.
InsertMapSortInGroupingExpressions,
+ InsertMapSortInRepartitionExpressions,
ComputeCurrentTime,
ReplaceCurrentLike(catalogManager),
SpecialDatetimeValues,
Expand Down
New file: DataFrameRepartitionSuite.scala

@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.functions.{col, lit, spark_partition_id, typedLit, when}
import org.apache.spark.sql.test.SharedSparkSession

class DataFrameRepartitionSuite extends QueryTest with SharedSparkSession {

test("repartition by MapType") {
Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
val df = spark.range(20)
.withColumn("c1",
when(col("id") % 3 === 1, typedLit(Map(1 -> 2)))
.when(col("id") % 3 === 2, typedLit(Map(1 -> 1, 2 -> 2)))
.otherwise(typedLit(Map(2 -> 2, 1 -> 1))).cast(s"map<$dt, $dt>"))
.withColumn("c2", typedLit(Map(1 -> null)).cast(s"map<$dt, $dt>"))
.withColumn("c3", lit(null).cast(s"map<$dt, $dt>"))

assertRepartitionNumber(df.repartition(4, col("c1")), 2)
assertRepartitionNumber(df.repartition(4, col("c2")), 1)
assertRepartitionNumber(df.repartition(4, col("c3")), 1)
assertRepartitionNumber(df.repartition(4, col("c1"), col("c2")), 2)
assertRepartitionNumber(df.repartition(4, col("c1"), col("c3")), 2)
assertRepartitionNumber(df.repartition(4, col("c1"), col("c2"), col("c3")), 2)
assertRepartitionNumber(df.repartition(4, col("c2"), col("c3")), 2)
}
}

def assertRepartitionNumber(df: => DataFrame, max: Int): Unit = {
val dfGrouped = df.groupBy(spark_partition_id()).count()
// The number of non-empty result partitions can be at most `max`.
assert(dfGrouped.count() <= max, dfGrouped.queryExecution.simpleString)
}
}
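The expected maxima follow from map_sort's normalization: Map(1 -> 1, 2 -> 2) and Map(2 -> 2, 1 -> 1) become the same value, so c1 carries only two distinct partition keys, while c2 and c3 are constant across rows. A standalone sanity check of that argument (a sketch, assuming a `spark` session):

  import org.apache.spark.sql.functions.spark_partition_id

  // The same three-way CASE as the suite's c1 column, in plain SQL.
  val df = spark.range(20).selectExpr(
    "CASE WHEN id % 3 = 1 THEN map(1, 2) " +
    "WHEN id % 3 = 2 THEN map(1, 1, 2, 2) " +
    "ELSE map(2, 2, 1, 1) END AS c1")
  // Two distinct keys after normalization, so at most two non-empty partitions.
  df.repartition(4, df("c1")).groupBy(spark_partition_id()).count().count()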
