Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-788] Quick fix for randomSplit on reordered partitions #789

Merged
merged 1 commit into from
Mar 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,38 +69,40 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
}

test("randomSplit on reordered partitions") {
// TODO: Once fixed, will removed this config
withSQLConf("spark.oap.sql.columnar.enableComplexType" -> "false") {
def testNonOverlappingSplits(data: DataFrame): Unit = {
val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
assert(splits.length == 2, "wrong number of splits")

// Verify that the splits span the entire dataset
assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)

// Verify that the splits don't overlap
assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)

// Verify that the results are deterministic across multiple runs
val firstRun = splits.toSeq.map(_.collect().toSeq)
val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
assert(firstRun == secondRun)
}

def testNonOverlappingSplits(data: DataFrame): Unit = {
val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
assert(splits.length == 2, "wrong number of splits")

// Verify that the splits span the entire dataset
assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)

// Verify that the splits don't overlap
assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)

// Verify that the results are deterministic across multiple runs
val firstRun = splits.toSeq.map(_.collect().toSeq)
val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
assert(firstRun == secondRun)
// This test ensures that randomSplit does not create overlapping splits even when the
// underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
// rows in each partition.
val dataWithInts = sparkContext.parallelize(1 to 600, 2)
.mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
.map(i => (i, Map(i -> i.toString)))
.mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
.map(i => (i, Array(Map(i -> i.toString))))
.mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")

testNonOverlappingSplits(dataWithInts)
testNonOverlappingSplits(dataWithMaps)
testNonOverlappingSplits(dataWithArrayOfMaps)
}

// This test ensures that randomSplit does not create overlapping splits even when the
// underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
// rows in each partition.
val dataWithInts = sparkContext.parallelize(1 to 600, 2)
.mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
.map(i => (i, Map(i -> i.toString)))
.mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
.map(i => (i, Array(Map(i -> i.toString))))
.mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")

testNonOverlappingSplits(dataWithInts)
testNonOverlappingSplits(dataWithMaps)
testNonOverlappingSplits(dataWithArrayOfMaps)
}

test("pearson correlation") {
Expand Down