diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index f8baa9f5e01..3b5c4e036a2 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -556,7 +556,7 @@ Accelerator supports are described below.
S |
NS |
NS |
-PS not allowed for grouping expressions; UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
+PS not allowed for grouping expressions if containing Struct as child; UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
PS not allowed for grouping expressions; UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
PS not allowed for grouping expressions if containing Array or Map as child; UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
@@ -748,7 +748,7 @@ Accelerator supports are described below.
S |
S |
NS |
-PS Round-robin partitioning is not supported if spark.sql.execution.sortBeforeRepartition is true; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
PS Round-robin partitioning is not supported if spark.sql.execution.sortBeforeRepartition is true; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
PS Round-robin partitioning is not supported for nested structs if spark.sql.execution.sortBeforeRepartition is true; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
NS |
@@ -7854,45 +7854,45 @@ are limited.
None |
project |
input |
- |
- |
- |
- |
- |
S |
S |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
+S |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for TIMESTAMP |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+S |
result |
- |
- |
- |
- |
- |
S |
S |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
+S |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for TIMESTAMP |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+S |
KnownNotNull |
@@ -18852,9 +18852,9 @@ as `a` don't show up in the table. They are controlled by the rules for
S |
NS |
NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, MAP, UDT |
NS |
-NS |
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, MAP, UDT |
NS |
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index ffa22896d6a..df9dafc1849 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -116,6 +116,19 @@
('b', FloatGen(nullable=(True, 10.0), special_cases=[(float('nan'), 10.0)])),
('c', LongGen())]
+# grouping single-level lists
+_grpkey_list_with_non_nested_children = [[('a', RepeatSeqGen(ArrayGen(data_gen), length=3)),
+ ('b', IntegerGen())] for data_gen in all_basic_gens + decimal_gens]
+
+# grouping multi-level structs with arrays
+_grpkey_nested_structs_with_array_basic_child = [
+ ('a', RepeatSeqGen(StructGen([
+ ['aa', IntegerGen()],
+ ['ab', ArrayGen(IntegerGen())]]),
+ length=20)),
+ ('b', IntegerGen()),
+ ('c', NullGen())]
+
_nan_zero_float_special_cases = [
(float('nan'), 5.0),
(NEG_FLOAT_NAN_MIN_VALUE, 5.0),
@@ -318,7 +331,7 @@ def test_hash_reduction_decimal_overflow_sum(precision):
# some optimizations are conspiring against us.
conf = {'spark.rapids.sql.batchSizeBytes': '128m'})
-@pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn)
+@pytest.mark.parametrize('data_gen', [_grpkey_nested_structs_with_array_basic_child, _longs_with_nulls] + _grpkey_list_with_non_nested_children, ids=idfn)
def test_hash_grpby_sum_count_action(data_gen):
assert_gpu_and_cpu_row_counts_equal(
lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b'))
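
For readers outside the test harness, a minimal PySpark sketch (session, schemas, and column names are hypothetical) of the grouping shapes the new generators exercise: single-level arrays of basic types now group on the GPU, while arrays of structs are tagged off and fall back.

```python
from pyspark.sql import Row, SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()

# Single-level arrays of basic types can now serve as GroupBy keys on the GPU.
flat = spark.createDataFrame(
    [([1, 2], 10), ([1, 2], 20), ([3], 30)],
    "a: array<int>, b: int")
flat.groupBy("a").agg(f.sum("b")).show()

# Arrays of structs are still tagged off the GPU by tagPlanForGpu and fall
# back to the CPU hash aggregate.
nested = spark.createDataFrame(
    [([Row(aa=1)], 10), ([Row(aa=1)], 20)],
    "a: array<struct<aa:int>>, b: int")
nested.groupBy("a").agg(f.sum("b")).show()
```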
diff --git a/integration_tests/src/main/python/repart_test.py b/integration_tests/src/main/python/repart_test.py
index 7b77b7be426..b12a680d3eb 100644
--- a/integration_tests/src/main/python/repart_test.py
+++ b/integration_tests/src/main/python/repart_test.py
@@ -214,10 +214,23 @@ def test_round_robin_sort_fallback(data_gen):
lambda spark : gen_df(spark, data_gen).withColumn('extra', lit(1)).repartition(13),
'ShuffleExchangeExec')
+@allow_non_gpu("ProjectExec", "ShuffleExchangeExec")
+@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
+@pytest.mark.parametrize('num_parts', [2, 10, 17, 19, 32], ids=idfn)
+@pytest.mark.parametrize('gen', [([('ag', ArrayGen(StructGen([('b1', long_gen)])))], ['ag'])], ids=idfn)
+def test_hash_repartition_exact_fallback(gen, num_parts):
+ data_gen = gen[0]
+ part_on = gen[1]
+ assert_gpu_fallback_collect(
+ lambda spark : gen_df(spark, data_gen, length=1024) \
+ .repartition(num_parts, *part_on) \
+ .withColumn('id', f.spark_partition_id()) \
+ .selectExpr('*'), "ShuffleExchangeExec")
+
@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
@pytest.mark.parametrize('num_parts', [1, 2, 10, 17, 19, 32], ids=idfn)
@pytest.mark.parametrize('gen', [
- ([('a', boolean_gen)], ['a']),
+ ([('a', boolean_gen)], ['a']),
([('a', byte_gen)], ['a']),
([('a', short_gen)], ['a']),
([('a', int_gen)], ['a']),
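
A hedged sketch of the fallback `test_hash_repartition_exact_fallback` asserts, in plain PySpark (the DataFrame shape is hypothetical): hash-repartitioning on an array-of-struct key is rejected by `tagPartForGpu`, so the shuffle stays on the CPU.

```python
from pyspark.sql import Row, SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [([Row(b1=1)],), ([Row(b1=2)],)],
    "ag: array<struct<b1:bigint>>")

# tagPartForGpu rejects HashPartitioning on arrays of structs, so this
# shuffle is expected to stay on the CPU (ShuffleExchangeExec in the plan).
df.repartition(13, "ag").withColumn("id", f.spark_partition_id()).explain()
```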
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 9ec6448df04..2341fea9b9d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1560,9 +1560,7 @@ object GpuOverrides extends Logging {
}),
expr[KnownFloatingPointNormalized](
"Tag to prevent redundant normalization",
- ExprChecks.unaryProjectInputMatchesOutput(
- TypeSig.DOUBLE + TypeSig.FLOAT,
- TypeSig.DOUBLE + TypeSig.FLOAT),
+ ExprChecks.unaryProjectInputMatchesOutput(TypeSig.all, TypeSig.all),
(a, conf, p, r) => new UnaryExprMeta[KnownFloatingPointNormalized](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
GpuKnownFloatingPointNormalized(child)
@@ -3567,11 +3565,26 @@ object GpuOverrides extends Logging {
// This needs to match what murmur3 supports.
PartChecks(RepeatingParamCheck("hash_key",
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
- TypeSig.STRUCT).nested(), TypeSig.all)),
+ TypeSig.STRUCT + TypeSig.ARRAY).nested(),
+ TypeSig.all)
+ ),
(hp, conf, p, r) => new PartMeta[HashPartitioning](hp, conf, p, r) {
override val childExprs: Seq[BaseExprMeta[_]] =
hp.expressions.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+ override def tagPartForGpu(): Unit = {
+ val arrayWithStructsHashing = hp.expressions.exists(e =>
+ TrampolineUtil.dataTypeExistsRecursively(e.dataType,
+ dt => dt match {
+ case ArrayType(_: StructType, _) => true
+ case _ => false
+ })
+ )
+ if (arrayWithStructsHashing) {
+ willNotWorkOnGpu("hashing arrays with structs is not supported")
+ }
+ }
+
override def convertToGpu(): GpuPartitioning =
GpuHashPartitioning(childExprs.map(_.convertToGpu()), hp.numPartitions)
}),
@@ -3820,7 +3833,7 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.STRUCT, "Round-robin partitioning is not supported for nested " +
s"structs if ${SQLConf.SORT_BEFORE_REPARTITION.key} is true")
.withPsNote(
- Seq(TypeEnum.ARRAY, TypeEnum.MAP),
+ Seq(TypeEnum.MAP),
"Round-robin partitioning is not supported if " +
s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true"),
TypeSig.all),
@@ -3882,10 +3895,12 @@ object GpuOverrides extends Logging {
"The backend for hash based aggregations",
ExecChecks(
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
- TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT)
+ TypeSig.MAP + TypeSig.STRUCT + TypeSig.ARRAY)
.nested()
- .withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP),
+ .withPsNote(TypeEnum.MAP,
"not allowed for grouping expressions")
+ .withPsNote(TypeEnum.ARRAY,
+ "not allowed for grouping expressions if containing Struct as child")
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array or Map as child"),
TypeSig.all),
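
One reading of the `KnownFloatingPointNormalized` widening above, offered as a hedged illustration: assuming Spark's NormalizeFloatingNumbers rule wraps any grouping key whose nested children contain a float or double, arrays becoming valid grouping keys means the tag now sees ARRAY (and other nested) inputs, hence `TypeSig.all`. Column names below are hypothetical.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([([1.0, float("nan")],)], "a: array<double>")

# The optimized plan wraps the normalized grouping key in
# KnownFloatingPointNormalized; grouping on an array of doubles therefore
# feeds a nested type into the tag.
df.groupBy("a").count().explain(True)
```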
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
index a3cd116927f..44f202db6c3 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.{ExplainUtils, SortExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.rapids.{CpuToGpuAggregateBufferConverter, CudfAggregate, GpuAggregateExpression, GpuToCpuAggregateBufferConverter}
import org.apache.spark.sql.rapids.execution.{GpuShuffleMeta, TrampolineUtil}
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
object AggregateUtils {
@@ -847,13 +847,27 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan](
groupingExpressions ++ aggregateExpressions ++ aggregateAttributes ++ resultExpressions
override def tagPlanForGpu(): Unit = {
- // We don't support Arrays and Maps as GroupBy keys yet, even they are nested in Structs. So,
+ // We don't support Maps as GroupBy keys yet, even if they are nested in Structs. So,
// we need to run recursive type check on the structs.
- val arrayOrMapGroupings = agg.groupingExpressions.exists(e =>
+ val mapGroupings = agg.groupingExpressions.exists(e =>
TrampolineUtil.dataTypeExistsRecursively(e.dataType,
- dt => dt.isInstanceOf[ArrayType] || dt.isInstanceOf[MapType]))
- if (arrayOrMapGroupings) {
- willNotWorkOnGpu("ArrayTypes or MapTypes in grouping expressions are not supported")
+ dt => dt.isInstanceOf[MapType]))
+ if (mapGroupings) {
+ willNotWorkOnGpu("MapTypes in grouping expressions are not supported")
+ }
+
+  // We support Arrays as grouping expressions, but not if the child is a struct. So we need to
+  // run a recursive type check on arrays of structs.
+ val arrayWithStructsGroupings = agg.groupingExpressions.exists(e =>
+ TrampolineUtil.dataTypeExistsRecursively(e.dataType,
+ dt => dt match {
+ case ArrayType(_: StructType, _) => true
+ case _ => false
+ })
+ )
+ if (arrayWithStructsGroupings) {
+ willNotWorkOnGpu("ArrayTypes with Struct children in grouping expressions are not " +
+ "supported")
}
tagForReplaceMode()
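
The recursive checks added in both `tagPartForGpu` and `tagPlanForGpu` follow the same pattern; a Python analogue (a hypothetical helper mirroring `TrampolineUtil.dataTypeExistsRecursively`, not plugin code) using `pyspark.sql.types`:

```python
from pyspark.sql.types import ArrayType, MapType, StructType

def exists_recursively(dt, pred):
    """Return True if pred holds for dt or for any nested child type."""
    if pred(dt):
        return True
    if isinstance(dt, ArrayType):
        return exists_recursively(dt.elementType, pred)
    if isinstance(dt, MapType):
        return (exists_recursively(dt.keyType, pred)
                or exists_recursively(dt.valueType, pred))
    if isinstance(dt, StructType):
        return any(exists_recursively(f.dataType, pred) for f in dt.fields)
    return False

# The two predicates the patch installs: reject any MapType grouping key,
# and reject arrays whose element type is a struct.
is_map = lambda dt: isinstance(dt, MapType)
is_array_of_struct = lambda dt: (isinstance(dt, ArrayType)
                                 and isinstance(dt.elementType, StructType))
```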
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
index f9dd411cbec..16573bd33e5 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
@@ -559,7 +559,7 @@ case class GpuBasicMin(child: Expression) extends GpuMin(child)
*/
case class GpuFloatMin(child: Expression) extends GpuMin(child)
with GpuReplaceWindowFunction {
-
+
override val dataType: DataType = child.dataType match {
case FloatType | DoubleType => child.dataType
case t => throw new IllegalStateException(s"child type $t is not FloatType or DoubleType")
@@ -606,7 +606,7 @@ case class GpuFloatMin(child: Expression) extends GpuMin(child)
// Else return the min value
override lazy val postUpdate: Seq[Expression] = Seq(
GpuIf(
- updateAllNansOrNulls.attr,
+ updateAllNansOrNulls.attr,
GpuIf(
updateHasNan.attr, GpuLiteral(nan, dataType), GpuLiteral(null, dataType)
),
@@ -668,7 +668,7 @@ object GpuMax {
abstract class GpuMax(child: Expression) extends GpuAggregateFunction
with GpuBatchedRunningWindowWithFixer
with GpuAggregateWindowFunction
- with GpuRunningWindowFunction
+ with GpuRunningWindowFunction
with Serializable {
override lazy val initialValues: Seq[GpuLiteral] = Seq(GpuLiteral(null, child.dataType))
override lazy val inputProjection: Seq[Expression] = Seq(child)
@@ -730,7 +730,7 @@ case class GpuBasicMax(child: Expression) extends GpuMax(child)
* column `isNan`. If any value in this column is true, return `Nan`,
* Else, return what `GpuBasicMax` returns.
*/
-case class GpuFloatMax(child: Expression) extends GpuMax(child)
+case class GpuFloatMax(child: Expression) extends GpuMax(child)
with GpuReplaceWindowFunction{
override val dataType: DataType = child.dataType match {
@@ -756,13 +756,13 @@ case class GpuFloatMax(child: Expression) extends GpuMax(child)
override lazy val updateAggregates: Seq[CudfAggregate] = Seq(updateMaxVal, updateIsNan)
// If there is `Nan` value in the target column, return `Nan`
// else return what the `CudfMax` returns
- override lazy val postUpdate: Seq[Expression] =
+ override lazy val postUpdate: Seq[Expression] =
Seq(
GpuIf(updateIsNan.attr, GpuLiteral(nan, dataType), updateMaxVal.attr)
)
// Same logic as the `inputProjection` stage.
- override lazy val preMerge: Seq[Expression] =
+ override lazy val preMerge: Seq[Expression] =
Seq(evaluateExpression, GpuIsNan(evaluateExpression))
// Same logic as the `updateAggregates` stage.
override lazy val mergeAggregates: Seq[CudfAggregate] = Seq(mergeMaxVal, mergeIsNan)
diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv
index a90a3772580..de7d9bb117b 100644
--- a/tools/generated_files/supportedExprs.csv
+++ b/tools/generated_files/supportedExprs.csv
@@ -264,8 +264,8 @@ IsNotNull,S,`isnotnull`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,P
IsNotNull,S,`isnotnull`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
IsNull,S,`isnull`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS
IsNull,S,`isnull`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
-KnownFloatingPointNormalized,S, ,None,project,input,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
-KnownFloatingPointNormalized,S, ,None,project,result,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+KnownFloatingPointNormalized,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S
+KnownFloatingPointNormalized,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S
KnownNotNull,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS
KnownNotNull,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS
Lag,S,`lag`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NS,PS,NS