[SPARK-50525][SQL] Define InsertMapSortInRepartitionExpressions Optimizer Rule #49144

Status: Closed · wants to merge 6 commits
CheckAnalysis.scala:

@@ -884,7 +884,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsBase
           o.failAnalysis(
             errorClass = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
             messageParameters = Map(
-              "expr" -> variantExpr.sql,
+              "expr" -> toSQLExpr(variantExpr),
              "dataType" -> toSQLType(variantExpr.dataType)))

      case o if o.expressions.exists(!_.deterministic) &&
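This swaps the raw variantExpr.sql rendering for the standard error-message quoting helper, so the expression in UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT errors is now wrapped in double quotes and printed without its qualifier (for example, tv.v becomes "v"), matching the updated expectations in DataFrameSuite below.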
@@ -20,41 +20,38 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable

 import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, RepartitionByExpression}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, REPARTITION_OPERATION}
 import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
 import org.apache.spark.util.ArrayImplicits.SparkArrayOps

 /**
- * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
- * in the correct order before grouping:
+ * Adds [[MapSort]] to [[Aggregate]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before grouping:
  *
  * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
  * SELECT _groupingmapsort as map_column, COUNT(*) FROM (
  *   SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
  * ) GROUP BY _groupingmapsort
  */
 object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
-  private def shouldAddMapSort(expr: Expression): Boolean = {
-    expr.dataType.existsRecursively(_.isInstanceOf[MapType])
-  }
+  import InsertMapSortExpression._

   override def apply(plan: LogicalPlan): LogicalPlan = {
-    if (!plan.containsPattern(TreePattern.AGGREGATE)) {
+    if (!plan.containsPattern(AGGREGATE)) {
       return plan
     }
     val shouldRewrite = plan.exists {
-      case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
+      case agg: Aggregate if agg.groupingExpressions.exists(mapTypeExistsRecursively) => true
       case _ => false
     }
     if (!shouldRewrite) {
       return plan
     }

     plan transformUpWithNewOutput {
-      case agg @ Aggregate(groupingExprs, aggregateExpressions, child, _)
-          if agg.groupingExpressions.exists(shouldAddMapSort) =>
+      case agg @ Aggregate(groupingExprs, aggregateExpressions, child, hint) =>
         val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
         val newGroupingKeys = groupingExprs.map { expr =>
           val inserted = insertMapSortRecursively(expr)
@@ -77,15 +74,53 @@
         }.asInstanceOf[NamedExpression]
         }
         val newChild = Project(child.output ++ exprToMapSort.values, child)
-        val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
+        val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild, hint)
         newAgg -> agg.output.zip(newAgg.output)
     }
   }
 }
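Two details fall out of this hunk: the shared helpers now live in InsertMapSortExpression, and the rebuilt Aggregate threads the hint through instead of dropping it. As a minimal sketch of why the sort matters (column name illustrative, mirroring the test data later in this PR), maps holding the same entries in a different insertion order must fall into the same group:

import org.apache.spark.sql.functions.{col, typedLit, when}

val df = spark.range(2)
  .withColumn("m",
    when(col("id") === 0, typedLit(Map(1 -> 1, 2 -> 2)))
      .otherwise(typedLit(Map(2 -> 2, 1 -> 1))))
// With the rule in place, both rows group together: one group with count 2.
df.groupBy("m").count().show()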

+/**
+ * Adds [[MapSort]] to [[RepartitionByExpression]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before repartitioning:
+ *
+ * SELECT * FROM TABLE DISTRIBUTE BY map_column =>
+ * SELECT * FROM TABLE DISTRIBUTE BY map_sort(map_column)
+ */
+object InsertMapSortInRepartitionExpressions extends Rule[LogicalPlan] {
+  import InsertMapSortExpression._
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.transformUpWithPruning(_.containsPattern(REPARTITION_OPERATION)) {
+      case rep: RepartitionByExpression
+          if rep.partitionExpressions.exists(mapTypeExistsRecursively) =>
+        val exprToMapSort = new mutable.HashMap[Expression, Expression]
+        val newPartitionExprs = rep.partitionExpressions.map { expr =>
+          val inserted = insertMapSortRecursively(expr)
+          if (expr.ne(inserted)) {
+            exprToMapSort.getOrElseUpdate(expr.canonicalized, inserted)
+          } else {
+            expr
+          }
+        }
+        rep.copy(partitionExpressions = newPartitionExprs)
+    }
+  }
+}

Review thread on the new rule:

Reviewer (Contributor): Can we combine these two rules so that we only need to traverse the plan once?

Author (@ostronaut, Jan 6, 2025): Initially I wanted to do the same, but the logic for InsertMapSortInGroupingExpressions and InsertMapSortInRepartitionExpressions is quite different: grouping produces a new output after applying the changes, while repartition only updates the existing RepartitionByExpression by replacing partitionExpressions. Also, there is a dependency between InsertMapSortInGroupingExpressions and PullOutGroupingExpressions, as mentioned in this comment. For those reasons I decided to split them into separate rules. But if you think the performance saving from the reduced traversal will be significant, we can combine them.

Reviewer (Contributor): I think it's fine to use transformUpWithNewOutput for both. If we hit RepartitionByExpression, we return Nil as the new output.

Author (@ostronaut): I think this would make things more complex, while there is no need to return a new output for RepartitionByExpression. Also, to prevent a traversal of every plan, we have added two conditions in InsertMapSortInRepartitionExpressions:

  1. _.containsPattern(REPARTITION_OPERATION) as the condition for transformUpWithPruning.
  2. if rep.partitionExpressions.exists(mapTypeExistsRecursively) in the case matching.

So I would keep these two independent from each other.
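For reference, the combined rule the reviewer floated might look roughly like the sketch below. This is not what was merged, and rewriteAggregate is a hypothetical helper standing in for the Aggregate rewrite above:

plan.transformUpWithNewOutput {
  case agg: Aggregate if agg.groupingExpressions.exists(mapTypeExistsRecursively) =>
    val newAgg = rewriteAggregate(agg) // hypothetical: the map-sort Aggregate rewrite
    newAgg -> agg.output.zip(newAgg.output)
  case rep: RepartitionByExpression
      if rep.partitionExpressions.exists(mapTypeExistsRecursively) =>
    // Output attributes are unchanged, so no attribute remapping is returned.
    rep.copy(partitionExpressions =
      rep.partitionExpressions.map(insertMapSortRecursively)) -> Nil
}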

+private[optimizer] object InsertMapSortExpression {
+
   /**
-   * Inserts MapSort recursively taking into account when it is nested inside a struct or array.
+   * Returns true if the expression contains a [[MapType]] in DataType tree.
    */
-  private def insertMapSortRecursively(e: Expression): Expression = {
+  def mapTypeExistsRecursively(expr: Expression): Boolean = {
+    expr.dataType.existsRecursively(_.isInstanceOf[MapType])
+  }
+
+  /**
+   * Inserts [[MapSort]] recursively taking into account when it is nested inside a struct or array.
+   */
+  def insertMapSortRecursively(e: Expression): Expression = {
     e.dataType match {
       case m: MapType =>
         // Check if value type of MapType contains MapType (possibly nested)
@@ -122,5 +157,4 @@
       case _ => e
     }
   }
-
 }
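The collapsed middle of insertMapSortRecursively holds the nested cases. Judging from the imports at the top of the file (ArrayTransform, CreateNamedStruct, GetStructField, LambdaFunction, NamedLambdaVariable), a map nested inside an array is handled by pushing the sort through a lambda; a hedged, self-contained sketch of that shape:

import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, AttributeReference, LambdaFunction, MapSort, NamedLambdaVariable}
import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType}

val mapType = MapType(IntegerType, IntegerType)
val arr = AttributeReference("a", ArrayType(mapType))() // a: array<map<int, int>>
val x = NamedLambdaVariable("x", mapType, nullable = true)
// array<map<int, int>>  =>  transform(a, x -> map_sort(x))
val sorted = ArrayTransform(arr, LambdaFunction(MapSort(x), Seq(x)))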
Optimizer.scala:

@@ -321,6 +321,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         // so the grouping keys can only be attribute and literal which makes
         // `InsertMapSortInGroupingExpressions` easy to insert `MapSort`.
         InsertMapSortInGroupingExpressions,
+        InsertMapSortInRepartitionExpressions,
         ComputeCurrentTime,
         ReplaceCurrentLike(catalogManager),
         SpecialDatetimeValues,
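With the rule registered, repartitioning by a map column goes through map_sort before hashing. A quick way to see the effect (illustrative; exact attribute numbering will differ):

import org.apache.spark.sql.functions.{col, typedLit}

spark.range(10)
  .withColumn("m", typedLit(Map(2 -> "b", 1 -> "a")))
  .repartition(4, col("m"))
  .explain()
// The optimized plan is expected to show the partitioning expression
// wrapped as map_sort(m#...) rather than the raw map column.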
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala (32 additions, 5 deletions):

@@ -316,15 +316,15 @@ class DataFrameSuite extends QueryTest
       exception = intercept[AnalysisException](df.repartition(5, col("v"))),
       condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
       parameters = Map(
-        "expr" -> "v",
+        "expr" -> "\"v\"",
         "dataType" -> "\"VARIANT\"")
     )
     // nested variant column
     checkError(
       exception = intercept[AnalysisException](df.repartition(5, col("s"))),
       condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
       parameters = Map(
-        "expr" -> "s",
+        "expr" -> "\"s\"",
         "dataType" -> "\"STRUCT<v: VARIANT NOT NULL>\"")
     )
     // variant producing expression
@@ -333,7 +333,7 @@
       intercept[AnalysisException](df.repartition(5, parse_json(col("id").cast("string")))),
       condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
       parameters = Map(
-        "expr" -> "parse_json(CAST(id AS STRING))",
+        "expr" -> "\"parse_json(CAST(id AS STRING))\"",
         "dataType" -> "\"VARIANT\"")
     )
     // Partitioning by non-variant column works
@@ -350,7 +350,7 @@
       exception = intercept[AnalysisException](sql("SELECT * FROM tv DISTRIBUTE BY v")),
       condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
       parameters = Map(
-        "expr" -> "tv.v",
+        "expr" -> "\"v\"",
         "dataType" -> "\"VARIANT\""),
       context = ExpectedContext(
         fragment = "DISTRIBUTE BY v",
@@ -361,7 +361,7 @@
       exception = intercept[AnalysisException](sql("SELECT * FROM tv DISTRIBUTE BY s")),
       condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
       parameters = Map(
-        "expr" -> "tv.s",
+        "expr" -> "\"s\"",
         "dataType" -> "\"STRUCT<v: VARIANT NOT NULL>\""),
       context = ExpectedContext(
         fragment = "DISTRIBUTE BY s",
@@ -428,6 +428,33 @@
     }
   }

+  test("repartition by MapType") {
+    Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
+      val df = spark.range(20)
+        .withColumn("c1",
+          when(col("id") % 3 === 1, typedLit(Map(1 -> 1)))
+            .when(col("id") % 3 === 2, typedLit(Map(1 -> 1, 2 -> 2)))
+            .otherwise(typedLit(Map(2 -> 2, 1 -> 1))).cast(s"map<$dt, $dt>"))
+        .withColumn("c2", typedLit(Map(1 -> null)).cast(s"map<$dt, $dt>"))
+        .withColumn("c3", lit(null).cast(s"map<$dt, $dt>"))
+
+      assertPartitionNumber(df.repartition(4, col("c1")), 2)
+      assertPartitionNumber(df.repartition(4, col("c2")), 1)
+      assertPartitionNumber(df.repartition(4, col("c3")), 1)
+      assertPartitionNumber(df.repartition(4, col("c1"), col("c2")), 2)
+      assertPartitionNumber(df.repartition(4, col("c1"), col("c3")), 2)
+      assertPartitionNumber(df.repartition(4, col("c1"), col("c2"), col("c3")), 2)
+      assertPartitionNumber(df.repartition(4, col("c2"), col("c3")), 2)
+    }
+  }

Review comment (Member), on the type list line: This newly added test case fails in NON-ANSI mode on the "decimal(10, 2)" case.
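The expected maxima follow from the number of distinct values after sorting: once map_sort normalizes Map(2 -> 2, 1 -> 1) to Map(1 -> 1, 2 -> 2), c1 takes only two distinct maps and so fills at most two partitions, while the constant c2 and the all-null c3 each hash to a single partition.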

+  private def assertPartitionNumber(df: => DataFrame, max: Int): Unit = {
+    val dfGrouped = df.groupBy(spark_partition_id()).count()
+    // The number of non-empty partitions can be lower than or equal to max,
+    // but never more than that.
+    assert(dfGrouped.count() <= max, dfGrouped.queryExecution.simpleString)
+  }

test("coalesce") {
intercept[IllegalArgumentException] {
testData.select("key").coalesce(0)