diff --git a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala index 3cd09b0d48113..00988bc480dc8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala @@ -39,6 +39,7 @@ class MLContext(val sparkContext: SparkContext) { * where the feature indices are converted to zero-based. * * @param path file or directory path in any Hadoop-supported file system URI + * @param minSplits min number of partitions, default: sparkContext.defaultMinSplits * @param numFeatures number of features, which will be determined from the input data if a * non-positive value is given. The default value is 0. * @param labelParser parser for labels, default: _.toDouble @@ -46,9 +47,13 @@ class MLContext(val sparkContext: SparkContext) { */ def libSVMFile( path: String, + minSplits: Int = sparkContext.defaultMinSplits, numFeatures: Int = 0, labelParser: String => Double = _.toDouble): RDD[LabeledPoint] = { - val parsed = sparkContext.textFile(path).map(_.trim).filter(!_.isEmpty).map(_.split(' ')) + val parsed = sparkContext.textFile(path, minSplits) + .map(_.trim) + .filter(!_.isEmpty) + .map(_.split(' ')) // Determine number of features. val d = if (numFeatures > 0) { numFeatures diff --git a/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala index 05be434590c48..6313978d546b9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala @@ -41,8 +41,8 @@ class MLContextSuite extends FunSuite with LocalSparkContext { val mlc = MLContext(sc) - val pointsWithNumFeatures = mlc.libSVMFile(tempDir.toURI.toString, 6).collect() - val pointsWithoutNumFeatures = mlc.libSVMFile(tempDir.toURI.toString, 0).collect() + val pointsWithNumFeatures = mlc.libSVMFile(tempDir.toURI.toString, numFeatures = 6).collect() + val pointsWithoutNumFeatures = mlc.libSVMFile(tempDir.toURI.toString).collect() for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) { assert(points.length === 3)