diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index aea3a2d6aa40f..efffab743d5f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression import org.scalatest.FunSuite import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} +import org.apache.spark.mllib.linalg.Vectors class LinearRegressionSuite extends FunSuite with LocalSparkContext { @@ -84,4 +85,37 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + // Test if we can correctly learn Y = 10*X1 + 10*X10000 + test("sparse linear regression without intercept") { + val denseRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42), 2) + val sparseRDD = denseRDD.map { case LabeledPoint(label, v) => + val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1)))) + LabeledPoint(label, sv) + }.cache() + val linReg = new LinearRegressionWithSGD().setIntercept(false) + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(sparseRDD) + + assert(model.intercept === 0.0) + + val weights = model.weights + assert(weights.size === 10000) + assert(weights(0) >= 9.0 && weights(0) <= 11.0) + assert(weights(9999) >= 9.0 && weights(9999) <= 11.0) + + val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17) + val sparseValidationData = validationData.map { case LabeledPoint(label, v) => + val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1)))) + LabeledPoint(label, sv) + } + val sparseValidationRDD = sc.parallelize(sparseValidationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(sparseValidationRDD.map(_.features)).collect(), sparseValidationData) + + // Test prediction on Array. + validatePrediction(sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) + } }