Skip to content

Commit

Permalink
fix a bug in GLM when intercept is not used
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 26, 2014
1 parent 8237df8 commit d7f629f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]

// Prepend an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features))
} else {
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
}

val initialWeightsWithIntercept = if (addIntercept) {
initialWeights.+:(1.0)
0.0 +: initialWeights
} else {
initialWeights
}

val weights = optimizer.optimize(data, initialWeightsWithIntercept)
val intercept = weights(0)
val weightsScaled = weights.tail
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)

val model = createModel(weightsScaled, intercept)
val (intercept, weights) = if (addIntercept) {
(weightsWithIntercept(0), weightsWithIntercept.tail)
} else {
(0.0, weightsWithIntercept)
}

logInfo("Final weights " + weights.mkString(","))
logInfo("Final intercept " + intercept)

logInfo("Final model weights " + model.weights.mkString(","))
logInfo("Final model intercept " + model.intercept)
model
createModel(weights, intercept)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.mllib.regression

import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
Expand Down Expand Up @@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}

// Test if we can correctly learn Y = 10*X1 + 10*X2
test("linear regression without intercept") {
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 42), 2).cache()
val linReg = new LinearRegressionWithSGD().setIntercept(false)
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)

val model = linReg.run(testRDD)

assert(model.intercept === 0.0)
assert(model.weights.length === 2)
assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)

val validationData = LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 17)
val validationRDD = sc.parallelize(validationData, 2).cache()

// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)

// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}

0 comments on commit d7f629f

Please sign in to comment.