Rework group by map type to fix bind reference exception
ulysses-you committed Aug 2, 2024
1 parent aca0d24 commit 2935bdb
Showing 5 changed files with 185 additions and 119 deletions.
@@ -0,0 +1,121 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
import org.apache.spark.util.ArrayImplicits.SparkArrayOps

/**
 * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
 * in the correct order before grouping:
 *
 *   SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
 *   SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
 *
 *   SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
 *   SELECT map_sort(map_column) AS map_column, COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
 *
 *   SELECT map_expr AS c, COUNT(*) FROM TABLE GROUP BY map_expr =>
 *   SELECT map_sort(map_expr) AS c, COUNT(*) FROM TABLE GROUP BY map_sort(map_expr)
 */
object AddMapSortInAggregate extends Rule[LogicalPlan] {
  private def shouldAddMapSort(expr: Expression): Boolean = {
    expr.dataType.existsRecursively(_.isInstanceOf[MapType]) && !expr.isInstanceOf[MapSort]
  }

  override def apply(plan: LogicalPlan): LogicalPlan = {
    val shouldRewrite = plan.exists {
      case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
      case _ => false
    }
    if (!shouldRewrite) {
      return plan
    }

    plan transformUpWithNewOutput {
      case agg @ Aggregate(groupingExpr, aggregateExpressions, child) =>
        val exprToMapSort = new mutable.HashMap[Expression, Expression]
        val newGroupingKeys = groupingExpr.map(insertMapSortRecursively(_, exprToMapSort))
        if (exprToMapSort.isEmpty) {
          agg -> agg.output.zip(agg.output)
        } else {
          val newAggregateExprs = aggregateExpressions.map { namedExpr =>
            val rewritten = namedExpr.transformUp {
              case e => exprToMapSort.getOrElse(e.canonicalized, e)
            }
            if (namedExpr.eq(rewritten)) {
              namedExpr
            } else {
              rewritten match {
                case e @ MapSort(named: NamedExpression) => Alias(e, named.name)()
                case other => other.asInstanceOf[NamedExpression]
              }
            }
          }
          val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, child)
          newAgg -> agg.output.zip(newAgg.output)
        }
    }
  }

  /**
   * Inserts MapSort recursively, taking into account maps nested inside structs and arrays.
   */
  private def insertMapSortRecursively(
      e: Expression,
      exprToMapSort: mutable.HashMap[Expression, Expression]): Expression = {
    e.dataType match {
      case m: MapType =>
        // Check whether the value type of the MapType itself contains a MapType
        // (possibly nested) and handle that case specially.
        val mapSortExpr = if (m.valueType.existsRecursively(_.isInstanceOf[MapType])) {
          MapFromArrays(MapKeys(e), insertMapSortRecursively(MapValues(e), exprToMapSort))
        } else {
          e
        }
        exprToMapSort.getOrElseUpdate(e.canonicalized, MapSort(mapSortExpr))

      case StructType(fields)
          if fields.exists(_.dataType.existsRecursively(_.isInstanceOf[MapType])) =>
        val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) =>
          Seq(Literal(f.name), insertMapSortRecursively(
            GetStructField(e, i, Some(f.name)), exprToMapSort))
        }.toImmutableArraySeq)
        if (struct.valExprs.forall(_.isInstanceOf[GetStructField])) {
          // No field needs MapSort processing, so return the original expression.
          e
        } else if (e.nullable) {
          If(IsNull(e), Literal(null, struct.dataType), struct)
        } else {
          struct
        }

      case ArrayType(et, containsNull) if et.existsRecursively(_.isInstanceOf[MapType]) =>
        val param = NamedLambdaVariable("x", et, containsNull)
        val funcBody = insertMapSortRecursively(param, exprToMapSort)
        ArrayTransform(e, LambdaFunction(funcBody, Seq(param)))

      case _ => e
    }
  }
}
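
For context, a minimal spark-shell sketch of how the rewrite can be observed (an illustration only, not part of the patch; it assumes a local SparkSession named `spark`, and the column name `m` is arbitrary). After optimization, the grouping key and the matching output column should appear wrapped in map_sort rather than as the raw map column:

  import org.apache.spark.sql.functions.count

  // Group by a map-typed column.
  val df = spark.range(10)
    .selectExpr("map(id % 3, id) AS m")
    .groupBy("m")
    .agg(count("*"))

  // Inspecting the optimized plan should show map_sort(m) as the grouping key.
  println(df.queryExecution.optimizedPlan)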

This file was deleted.

@@ -150,7 +150,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
}

val batches = (
Batch("Finish Analysis", Once, FinishAnalysis) ::
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
// We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
// may produce `With` expressions that need to be rewritten.
Batch("Rewrite With expression", Once, RewriteWithExpression) ::
@@ -246,8 +246,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
CollapseProject,
RemoveRedundantAliases,
RemoveNoopOperators) :+
Batch("InsertMapSortInGroupingExpressions", Once,
InsertMapSortInGroupingExpressions) :+
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
@@ -296,6 +294,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateView,
ReplaceExpressions,
RewriteNonCorrelatedExists,
AddMapSortInAggregate,
PullOutGroupingExpressions,
ComputeCurrentTime,
ReplaceCurrentLike(catalogManager),
@@ -127,7 +127,6 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
"org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
"org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
@@ -2162,8 +2162,9 @@ class DataFrameAggregateSuite extends QueryTest
)
}

private def assertAggregateOnDataframe(df: DataFrame,
expected: Int, aggregateColumn: String): Unit = {
private def assertAggregateOnDataframe(
df: => DataFrame,
expected: Int): Unit = {
val configurations = Seq(
Seq.empty[(String, String)], // hash aggregate is used by default
Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
@@ -2175,32 +2176,64 @@
Seq("spark.sql.test.forceApplySortAggregate" -> "true")
)

for (conf <- configurations) {
withSQLConf(conf: _*) {
assert(createAggregate(df).count() == expected)
// Make tests faster
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
for (conf <- configurations) {
withSQLConf(conf: _*) {
assert(df.count() == expected, df.queryExecution.simpleString)
}
}
}

def createAggregate(df: DataFrame): DataFrame = df.groupBy(aggregateColumn).agg(count("*"))
}

test("SPARK-47430 Support GROUP BY MapType") {
val numRows = 50

val dfSameInt = (0 until numRows)
.map(_ => Tuple1(Map(1 -> 1)))
.toDF("m0")
assertAggregateOnDataframe(dfSameInt, 1, "m0")

val dfSameFloat = (0 until numRows)
.map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 )))
.toDF("m0")
assertAggregateOnDataframe(dfSameFloat, 1, "m0")

val dfDifferent = (0 until numRows)
.map(i => Tuple1(Map(i -> i)))
.toDF("m0")
assertAggregateOnDataframe(dfDifferent, numRows, "m0")
def genMapData(dataType: String): String = {
s"""
|case when id % 4 == 0 then map()
|when id % 4 == 1 then map(cast(0 as $dataType), cast(0 as $dataType))
|when id % 4 == 2 then map(cast(0 as $dataType), cast(0 as $dataType),
| cast(1 as $dataType), cast(1 as $dataType))
|else map(cast(1 as $dataType), cast(1 as $dataType),
| cast(0 as $dataType), cast(0 as $dataType))
|end
|""".stripMargin
}
Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
withTempView("v") {
spark.range(20)
.selectExpr(
s"cast(1 as $dt) as c1",
s"${genMapData(dt)} as c2",
"map(c1, null) as c3",
s"cast(null as map<$dt, $dt>) as c4")
.createOrReplaceTempView("v")

assertAggregateOnDataframe(
spark.sql("SELECT count(*) FROM v GROUP BY c2"),
3)
assertAggregateOnDataframe(
spark.sql("SELECT c2, count(*) FROM v GROUP BY c2"),
3)
assertAggregateOnDataframe(
spark.sql("SELECT c1, c2, count(*) FROM v GROUP BY c1, c2"),
3)
assertAggregateOnDataframe(
spark.sql("SELECT map(c1, c1) FROM v GROUP BY map(c1, c1)"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT map(c1, c1), count(*) FROM v GROUP BY map(c1, c1)"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT c3, count(*) FROM v GROUP BY c3"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT c4, count(*) FROM v GROUP BY c4"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT c1, c2, c3, c4, count(*) FROM v GROUP BY c1, c2, c3, c4"),
3)
}
}
}
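
As a usage-level illustration of what the new rule guarantees (a hedged sketch, not part of the patch; it assumes a spark-shell session with spark.implicits._ in scope), maps that are equal up to key order now fall into the same group, because the keys are sorted before rows are compared:

  // Two rows whose maps differ only in key insertion order end up in one group.
  Seq(Map(1 -> "a", 2 -> "b"), Map(2 -> "b", 1 -> "a"))
    .toDF("m")
    .groupBy("m")
    .count()
    .show()
  // Expected: a single row with count = 2.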

test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
@@ -2209,12 +2242,16 @@
val dfSame = (0 until numRows)
.map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
.toDF("c0")
assertAggregateOnDataframe(dfSame, 1, "c0")
.groupBy($"c0")
.count()
assertAggregateOnDataframe(dfSame, 1)

val dfDifferent = (0 until numRows)
.map(i => Tuple1(new CalendarInterval(i, i, i)))
.toDF("c0")
assertAggregateOnDataframe(dfDifferent, numRows, "c0")
.groupBy($"c0")
.count()
assertAggregateOnDataframe(dfDifferent, numRows)
}

test("SPARK-46779: Group by subquery with a cached relation") {
