From 41fce544cadce5ed314b75f368abf79ee7fcd2da Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 10 Nov 2014 15:54:20 -0800 Subject: [PATCH] randomSplit() --- .../apache/spark/api/python/PythonRDD.scala | 13 +++++++++ python/pyspark/rdd.py | 28 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 45beb8fc8c925..78a5794bd557b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -757,6 +757,19 @@ private[spark] object PythonRDD extends Logging { converted.saveAsHadoopDataset(new JobConf(conf)) } } + + /** + * A helper to convert java.util.List[Double] into Array[Double] + * @param list + * @return + */ + def listToArrayDouble(list: JList[Double]): Array[Double] = { + val r = new Array[Double](list.size) + list.zipWithIndex.foreach { + case (v, i) => r(i) = v + } + r + } } private diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 08d047402625f..f29af793737f8 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -316,6 +316,34 @@ def sample(self, withReplacement, fraction, seed=None): assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) + def randomSplit(self, weights, seed=None): + """ + Randomly splits this RDD with the provided weights. 
+ + :param weights: weights for splits, will be normalized if they don't sum to 1 + :param seed: random seed + :return: split RDDs in a list + + >>> rdd = sc.parallelize(range(10), 1) + >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11) + >>> rdd1.collect() + [3, 6] + >>> rdd2.collect() + [0, 5, 7] + >>> rdd3.collect() + [1, 2, 4, 8, 9] + """ + ser = BatchedSerializer(PickleSerializer(), 1) + rdd = self._reserialize(ser) + jweights = ListConverter().convert([float(w) for w in weights], + self.ctx._gateway._gateway_client) + jweights = self.ctx._jvm.PythonRDD.listToArrayDouble(jweights) + if seed is None: + jrdds = rdd._jrdd.randomSplit(jweights) + else: + jrdds = rdd._jrdd.randomSplit(jweights, seed) + return [RDD(jrdd, self.ctx, ser) for jrdd in jrdds] + # this is ported from scala/spark/RDD.scala def takeSample(self, withReplacement, num, seed=None): """