[SPARK-4907][MLlib] Inconsistent loss and gradient in LeastSquaresGradient compared with R

In most academic papers and algorithm implementations, people use
L = 1/(2n) ||A weights - y||^2 instead of L = 1/n ||A weights - y||^2
for the least-squares loss. See Eq. (1) in http://web.stanford.edu/~hastie/Papers/glmnet.pdf

Since MLlib uses a different convention, this results in different residuals, and
all the statistical properties will differ from those of the glmnet package in R.

The model coefficients will still be the same under this change.
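
Written out per training example (with feature vector x, label y, and prediction w . x), the change in the diff below is:

  old:  loss = (w . x - y)^2          gradient = 2 * (w . x - y) * x
  new:  loss = (w . x - y)^2 / 2      gradient = (w . x - y) * x

Averaged over n examples this is L = 1/n ||A weights - y||^2 before the change and L = 1/(2n) ||A weights - y||^2 after it, matching the glmnet convention.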

Author: DB Tsai <dbtsai@alpinenow.com>

Closes #3746 from dbtsai/lir and squashes the following commits:

19c2e85 [DB Tsai] make stepsize twice to converge to the same solution
0b2c29c [DB Tsai] first commit
DB Tsai authored and mengxr committed Dec 23, 2014
1 parent c233ab3 commit a96b727
Showing 2 changed files with 8 additions and 8 deletions.
@@ -94,16 +94,16 @@ class LogisticGradient extends Gradient {
* :: DeveloperApi ::
* Compute gradient and loss for a Least-squared loss function, as used in linear regression.
* This is correct for the averaged least squares loss function (mean squared error)
- * L = 1/n ||A weights-y||^2
+ * L = 1/2n ||A weights-y||^2
* See also the documentation for the precise formulation.
*/
@DeveloperApi
class LeastSquaresGradient extends Gradient {
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
val diff = dot(data, weights) - label
- val loss = diff * diff
+ val loss = diff * diff / 2.0
val gradient = data.copy
- scal(2.0 * diff, gradient)
+ scal(diff, gradient)
(gradient, loss)
}

@@ -113,8 +113,8 @@ class LeastSquaresGradient extends Gradient {
weights: Vector,
cumGradient: Vector): Double = {
val diff = dot(data, weights) - label
- axpy(2.0 * diff, data, cumGradient)
- diff * diff
+ axpy(diff, data, cumGradient)
+ diff * diff / 2.0
}
}
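
As a quick sanity check of the new convention (a minimal sketch, not part of the commit; the object name is made up, and it assumes an MLlib build containing this change is on the classpath):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.LeastSquaresGradient

object LeastSquaresGradientCheck {
  def main(args: Array[String]): Unit = {
    // Single example: data = [1.0, 2.0], weights = [0.5, 0.5], label = 1.0,
    // so prediction = 0.5 * 1.0 + 0.5 * 2.0 = 1.5 and diff = 0.5.
    val gradient = new LeastSquaresGradient()
    val (grad, loss) = gradient.compute(
      Vectors.dense(1.0, 2.0), 1.0, Vectors.dense(0.5, 0.5))
    // New convention: loss = diff^2 / 2 = 0.125 (previously diff^2 = 0.25),
    // grad = diff * data = [0.5, 1.0]  (previously 2 * diff * data = [1.0, 2.0]).
    println(s"loss = $loss, gradient = $grad")
  }
}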

@@ -52,7 +52,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
// create model
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0, 0.0))
- .setStepSize(0.1)
+ .setStepSize(0.2)
.setNumIterations(25)

// generate sequence of simulated data
@@ -84,7 +84,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
// create model
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0))
- .setStepSize(0.1)
+ .setStepSize(0.2)
.setNumIterations(25)

// generate sequence of simulated data
@@ -118,7 +118,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
// create model initialized with true weights
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(10.0, 10.0))
- .setStepSize(0.1)
+ .setStepSize(0.2)
.setNumIterations(25)

// generate sequence of simulated data for testing
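
The step-size changes above (0.1 -> 0.2) follow directly from the loss change: a plain SGD update is

  w_{k+1} = w_k - stepSize * gradient

and the new gradient is exactly half the old one, so running with stepSize 0.2 reproduces the same iterates the tests previously obtained with stepSize 0.1 (the "make stepsize twice to converge to the same solution" commit above).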

2 comments on commit a96b727

@martinjaggi
Contributor


Minor comment: this will actually only keep the models the same for the unregularized case (otherwise people might want to re-check their choice of regularization parameter).

But in any case, it is indeed nice to have consistent 1/(2n) notation!

TODO: update the same in regression/Lasso.scala and regression/RidgeRegression.scala (and their corresponding test cases) for consistency.
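
(A short sketch of why the regularized models differ: halving the data-fit term doubles the relative weight of the penalty, since

  argmin_w  1/(2n) ||A w - y||^2 + lambda * R(w)  =  argmin_w  1/n ||A w - y||^2 + 2 * lambda * R(w)

so a regularized fit under the new convention only matches the old one if lambda is halved as well.)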

@dbtsai
Member

@dbtsai commented on a96b727 on Dec 24, 2014


@martinjaggi I agree. I plan to have another PR to fix lasso and ridge regression, since they are not computing the regularization and intercept correctly. In my next PR, I'll generalize it to elastic-net regularization, and the solutions will be exactly the same as those from R's glmnet package.
