[SPARK-50525][SQL] Define InsertMapSortInRepartitionExpressions Optimizer Rule

### What changes were proposed in this pull request?
In the current version of Spark, it is possible to use a `MapType` column for repartitioning. However, `MapData` does not implement `equals` and `hashCode` (see [SPARK-9415](https://issues.apache.org/jira/browse/SPARK-9415) and [[SPARK-16135][SQL] Remove hashCode and equals in ArrayBasedMapData](#13847)), so the hash values of equal maps can differ.

Attempting to run the `xxhash64` or `hash` function on a `MapType` column throws ```org.apache.spark.sql.catalyst.ExtendedAnalysisException: [DATATYPE_MISMATCH.HASH_MAP_TYPE] Cannot resolve "xxhash64(value)" due to data type mismatch: Input to the function `xxhash64` cannot contain elements of the "MAP" type. In Spark, same maps may have different hashcode, thus hash expressions are prohibited on "MAP" elements. To restore previous behavior set "spark.sql.legacy.allowHashOnMapType" to "true".;```

Similarly, when trying to run `ds.distinct(col("value"))`, where `value` is a `MapType` column, the following exception is thrown: ```org.apache.spark.sql.catalyst.ExtendedAnalysisException: [UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE] The feature is not supported: Cannot have MAP type columns in DataFrame which calls set operations (INTERSECT, EXCEPT, etc.), but the type of column `value` is "MAP<INT, STRING>".;```
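
For illustration, here is a minimal sketch that reproduces both errors (the column name and sample data are made up for this example; a `spark-shell`-style `SparkSession` named `spark` is assumed):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._ // assumes a SparkSession named `spark`

val ds = Seq(Map(1 -> "a", 2 -> "b")).toDF("value") // MAP<INT, STRING> column

// Fails with DATATYPE_MISMATCH.HASH_MAP_TYPE unless
// spark.sql.legacy.allowHashOnMapType is set to "true".
ds.select(xxhash64(col("value")))

// Fails with UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE, because distinct is
// planned as a set-style operation over the map column.
ds.select(col("value")).distinct()
```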

With the above in mind, a new `InsertMapSortInRepartitionExpressions` `Rule[LogicalPlan]` was implemented to insert `MapSort` for every `MapType` expression in `RepartitionByExpression.partitionExpressions`.
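
As a rough sketch of the effect (illustrative only; `ds` is the example DataFrame above with its `MAP<INT, STRING>` column `value`):

```scala
// Before this rule, a repartition by a map column hashes the map data as-is,
// so the result can depend on the internal entry order:
//   RepartitionByExpression [value], 4
// The rule wraps map-typed partition expressions in MapSort, so the hash is
// computed over a deterministically ordered map:
//   RepartitionByExpression [mapsort(value)], 4
val repartitioned = ds.repartition(4, col("value"))
repartitioned.explain(true) // the optimized logical plan should show the inserted map sort
```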

### Why are the changes needed?

To make the `repartition` API behave consistently for `MapType` columns, so that maps containing the same entries are always routed to the same partition.
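
For instance (a minimal sketch mirroring the intent of the new unit test; the data and column name are illustrative), two rows whose maps hold the same entries but were built in a different order should now land in the same partition:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._ // assumes a SparkSession named `spark`

val df = Seq(
  Map(1 -> "a", 2 -> "b"),
  Map(2 -> "b", 1 -> "a") // same entries, different construction order
).toDF("m")

// Both rows are expected to hash to the same partition, because partitioning
// now effectively happens on map_sort(m).
val nonEmptyPartitions = df.repartition(4, col("m"))
  .groupBy(spark_partition_id()).count() // one row per non-empty partition
  .count()
// nonEmptyPartitions should be 1
```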

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #49144 from ostronaut/features/map_repartition.

Authored-by: Dima <dimanowq@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
ostronaut authored and cloud-fan committed Jan 10, 2025
1 parent 68305ac commit a4f2870
Showing 4 changed files with 84 additions and 22 deletions.
@@ -884,7 +884,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
o.failAnalysis(
errorClass = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
messageParameters = Map(
"expr" -> variantExpr.sql,
"expr" -> toSQLExpr(variantExpr),
"dataType" -> toSQLType(variantExpr.dataType)))

case o if o.expressions.exists(!_.deterministic) &&
@@ -20,41 +20,38 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
- import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, RepartitionByExpression}
import org.apache.spark.sql.catalyst.rules.Rule
- import org.apache.spark.sql.catalyst.trees.TreePattern
+ import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, REPARTITION_OPERATION}
import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
import org.apache.spark.util.ArrayImplicits.SparkArrayOps

/**
- * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
- * in the correct order before grouping:
+ * Adds [[MapSort]] to [[Aggregate]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before grouping:
*
- * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
+ * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
* SELECT _groupingmapsort as map_column, COUNT(*) FROM (
* SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
* ) GROUP BY _groupingmapsort
*/
object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
- private def shouldAddMapSort(expr: Expression): Boolean = {
- expr.dataType.existsRecursively(_.isInstanceOf[MapType])
- }
+ import InsertMapSortExpression._

override def apply(plan: LogicalPlan): LogicalPlan = {
- if (!plan.containsPattern(TreePattern.AGGREGATE)) {
+ if (!plan.containsPattern(AGGREGATE)) {
return plan
}
val shouldRewrite = plan.exists {
- case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
+ case agg: Aggregate if agg.groupingExpressions.exists(mapTypeExistsRecursively) => true
case _ => false
}
if (!shouldRewrite) {
return plan
}

plan transformUpWithNewOutput {
- case agg @ Aggregate(groupingExprs, aggregateExpressions, child, _)
- if agg.groupingExpressions.exists(shouldAddMapSort) =>
+ case agg @ Aggregate(groupingExprs, aggregateExpressions, child, hint) =>
val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
val newGroupingKeys = groupingExprs.map { expr =>
val inserted = insertMapSortRecursively(expr)
@@ -77,15 +74,53 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
}.asInstanceOf[NamedExpression]
}
val newChild = Project(child.output ++ exprToMapSort.values, child)
- val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
+ val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild, hint)
newAgg -> agg.output.zip(newAgg.output)
}
}
}

/**
* Adds [[MapSort]] to [[RepartitionByExpression]] expressions containing map columns,
* as the key/value pairs need to be in the correct order before repartitioning:
*
* SELECT * FROM TABLE DISTRIBUTE BY map_column =>
* SELECT * FROM TABLE DISTRIBUTE BY map_sort(map_column)
*/
object InsertMapSortInRepartitionExpressions extends Rule[LogicalPlan] {
import InsertMapSortExpression._

override def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUpWithPruning(_.containsPattern(REPARTITION_OPERATION)) {
case rep: RepartitionByExpression
if rep.partitionExpressions.exists(mapTypeExistsRecursively) =>
val exprToMapSort = new mutable.HashMap[Expression, Expression]
val newPartitionExprs = rep.partitionExpressions.map { expr =>
val inserted = insertMapSortRecursively(expr)
if (expr.ne(inserted)) {
exprToMapSort.getOrElseUpdate(expr.canonicalized, inserted)
} else {
expr
}
}
rep.copy(partitionExpressions = newPartitionExprs)
}
}
}

private[optimizer] object InsertMapSortExpression {

/**
- * Inserts MapSort recursively taking into account when it is nested inside a struct or array.
+ * Returns true if the expression contains a [[MapType]] in DataType tree.
*/
- private def insertMapSortRecursively(e: Expression): Expression = {
+ def mapTypeExistsRecursively(expr: Expression): Boolean = {
+ expr.dataType.existsRecursively(_.isInstanceOf[MapType])
+ }
+
+ /**
+ * Inserts [[MapSort]] recursively taking into account when it is nested inside a struct or array.
+ */
+ def insertMapSortRecursively(e: Expression): Expression = {
e.dataType match {
case m: MapType =>
// Check if value type of MapType contains MapType (possibly nested)
@@ -122,5 +157,4 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
case _ => e
}
}

}
@@ -322,6 +322,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
// so the grouping keys can only be attribute and literal which makes
// `InsertMapSortInGroupingExpressions` easy to insert `MapSort`.
InsertMapSortInGroupingExpressions,
+ InsertMapSortInRepartitionExpressions,
ComputeCurrentTime,
ReplaceCurrentLike(catalogManager),
SpecialDatetimeValues,
37 changes: 32 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -316,15 +316,15 @@ class DataFrameSuite extends QueryTest
exception = intercept[AnalysisException](df.repartition(5, col("v"))),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
"expr" -> "v",
"expr" -> "\"v\"",
"dataType" -> "\"VARIANT\"")
)
// nested variant column
checkError(
exception = intercept[AnalysisException](df.repartition(5, col("s"))),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
"expr" -> "s",
"expr" -> "\"s\"",
"dataType" -> "\"STRUCT<v: VARIANT NOT NULL>\"")
)
// variant producing expression
@@ -333,7 +333,7 @@
intercept[AnalysisException](df.repartition(5, parse_json(col("id").cast("string")))),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
"expr" -> "parse_json(CAST(id AS STRING))",
"expr" -> "\"parse_json(CAST(id AS STRING))\"",
"dataType" -> "\"VARIANT\"")
)
// Partitioning by non-variant column works
@@ -350,7 +350,7 @@
exception = intercept[AnalysisException](sql("SELECT * FROM tv DISTRIBUTE BY v")),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
"expr" -> "tv.v",
"expr" -> "\"v\"",
"dataType" -> "\"VARIANT\""),
context = ExpectedContext(
fragment = "DISTRIBUTE BY v",
@@ -361,7 +361,7 @@
exception = intercept[AnalysisException](sql("SELECT * FROM tv DISTRIBUTE BY s")),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
"expr" -> "tv.s",
"expr" -> "\"s\"",
"dataType" -> "\"STRUCT<v: VARIANT NOT NULL>\""),
context = ExpectedContext(
fragment = "DISTRIBUTE BY s",
@@ -428,6 +428,33 @@ class DataFrameSuite extends QueryTest
}
}

test("repartition by MapType") {
Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
val df = spark.range(20)
.withColumn("c1",
when(col("id") % 3 === 1, typedLit(Map(1 -> 1)))
.when(col("id") % 3 === 2, typedLit(Map(1 -> 1, 2 -> 2)))
.otherwise(typedLit(Map(2 -> 2, 1 -> 1))).cast(s"map<$dt, $dt>"))
.withColumn("c2", typedLit(Map(1 -> null)).cast(s"map<$dt, $dt>"))
.withColumn("c3", lit(null).cast(s"map<$dt, $dt>"))

assertPartitionNumber(df.repartition(4, col("c1")), 2)
assertPartitionNumber(df.repartition(4, col("c2")), 1)
assertPartitionNumber(df.repartition(4, col("c3")), 1)
assertPartitionNumber(df.repartition(4, col("c1"), col("c2")), 2)
assertPartitionNumber(df.repartition(4, col("c1"), col("c3")), 2)
assertPartitionNumber(df.repartition(4, col("c1"), col("c2"), col("c3")), 2)
assertPartitionNumber(df.repartition(4, col("c2"), col("c3")), 2)
}
}

private def assertPartitionNumber(df: => DataFrame, max: Int): Unit = {
val dfGrouped = df.groupBy(spark_partition_id()).count()
// The resulting number of non-empty partitions can be lower than or equal to max,
// but never more than that.
assert(dfGrouped.count() <= max, dfGrouped.queryExecution.simpleString)
}

test("coalesce") {
intercept[IllegalArgumentException] {
testData.select("key").coalesce(0)
