From d7f629f902aab81cf3637f07f9eb9f7119d9230c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Mar 2014 23:05:35 -0700 Subject: [PATCH 1/2] fix a bug in GLM when intercept is not used --- .../GeneralizedLinearAlgorithm.scala | 21 ++++++++------- .../regression/LinearRegressionSuite.scala | 26 ++++++++++++++++++- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index b9621530efa22..3e1ed91bf6729 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0))) + input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features)) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - initialWeights.+:(1.0) + 0.0 +: initialWeights } else { initialWeights } - val weights = optimizer.optimize(data, initialWeightsWithIntercept) - val intercept = weights(0) - val weightsScaled = weights.tail + val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) - val model = createModel(weightsScaled, intercept) + val (intercept, weights) = if (addIntercept) { + (weightsWithIntercept(0), weightsWithIntercept.tail) + } else { + (0.0, weightsWithIntercept) + } + + logInfo("Final weights " + weights.mkString(",")) + logInfo("Final intercept " + intercept) - logInfo("Final model weights " + model.weights.mkString(",")) - logInfo("Final model intercept " + model.intercept) - model + createModel(weights, intercept) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 281f9df36ddb3..5d251bcbf35db 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.regression -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} @@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + // Test if we can correctly learn Y = 10*X1 + 10*X2 + test("linear regression without intercept") { + val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 42), 2).cache() + val linReg = new LinearRegressionWithSGD().setIntercept(false) + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(testRDD) + + assert(model.intercept === 0.0) + assert(model.weights.length === 2) + assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) + assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val validationData = LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 17) + val validationRDD = sc.parallelize(validationData, 2).cache() + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } } From 0e57aa43f61a62a70faf27aed58dea201b494809 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 26 Mar 2014 11:44:48 -0700 Subject: [PATCH 2/2] update Lasso and RidgeRegression to parse the weights correctly from GLM mark createModel protected mark predictPoint protected --- .../GeneralizedLinearAlgorithm.scala | 2 +- .../apache/spark/mllib/regression/Lasso.scala | 20 +++++++++++++------ .../mllib/regression/LinearRegression.scala | 20 +++++++++---------- .../mllib/regression/RidgeRegression.scala | 18 ++++++++++++----- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 3e1ed91bf6729..2166c6bb6b443 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -44,7 +44,7 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. */ - def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + protected def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, intercept: Double): Double /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index fb2bc9b92a51c..e397a573079e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -36,8 +36,10 @@ class LassoModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override protected def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -66,7 +68,7 @@ class LassoWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -77,10 +79,16 @@ class LassoWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override protected def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) + val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) new LassoModel(weightsScaled.data, interceptScaled) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 8ee40addb25d9..b4aafbe8bcaff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LinearRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + + override protected def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -55,8 +56,7 @@ class LinearRegressionWithSGD private ( var stepSize: Double, var numIterations: Int, var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LinearRegressionModel] - with Serializable { + extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable { val gradient = new LeastSquaresGradient() val updater = new SimpleUpdater() @@ -69,7 +69,7 @@ class LinearRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { + override protected def createModel(weights: Array[Double], intercept: Double) = { new LinearRegressionModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index c504d3d40c773..325e78c8f2233 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -36,8 +36,10 @@ class RidgeRegressionModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override protected def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept in RidgeRegression, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override protected def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)