diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index cdd256877..8df53f7ab 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -69,38 +69,40 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { } test("randomSplit on reordered partitions") { + // TODO: Once fixed, will removed this config + withSQLConf("spark.oap.sql.columnar.enableComplexType" -> "false") { + def testNonOverlappingSplits(data: DataFrame): Unit = { + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + + // Verify that the splits don't overlap + assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty) + + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } - def testNonOverlappingSplits(data: DataFrame): Unit = { - val splits = data.randomSplit(Array[Double](2, 3), seed = 1) - assert(splits.length == 2, "wrong number of splits") - - // Verify that the splits span the entire dataset - assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) - - // Verify that the splits don't overlap - assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty) - - // Verify that the results are deterministic across multiple runs - val firstRun = splits.toSeq.map(_.collect().toSeq) - val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) - assert(firstRun == secondRun) + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. + val dataWithInts = sparkContext.parallelize(1 to 600, 2) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int") + val dataWithMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Map(i -> i.toString))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map") + val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Array(Map(i -> i.toString)))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps") + + testNonOverlappingSplits(dataWithInts) + testNonOverlappingSplits(dataWithMaps) + testNonOverlappingSplits(dataWithArrayOfMaps) } - - // This test ensures that randomSplit does not create overlapping splits even when the - // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of - // rows in each partition. - val dataWithInts = sparkContext.parallelize(1 to 600, 2) - .mapPartitions(scala.util.Random.shuffle(_)).toDF("int") - val dataWithMaps = sparkContext.parallelize(1 to 600, 2) - .map(i => (i, Map(i -> i.toString))) - .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map") - val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2) - .map(i => (i, Array(Map(i -> i.toString)))) - .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps") - - testNonOverlappingSplits(dataWithInts) - testNonOverlappingSplits(dataWithMaps) - testNonOverlappingSplits(dataWithArrayOfMaps) } test("pearson correlation") {