Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
DB Tsai committed Jan 7, 2015
1 parent e21acc1 commit 64eefd0
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.optimization
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
import org.apache.spark.mllib.util.MLUtils

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -64,17 +65,12 @@ class LogisticGradient extends Gradient {
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
val gradient = data.copy
scal(gradientMultiplier, gradient)
val minusYP = if (label > 0) margin else -margin

// log1p is log(1+p) but more accurate for small p
// Following two equations are the same analytically but not numerically, e.g.,
// math.log1p(math.exp(1000)) == Infinity
// 1000 + math.log1p(math.exp(-1000)) == 1000.0
val loss =
if (minusYP < 0) {
math.log1p(math.exp(minusYP))
if (label > 0) {
// The following is equivalent to log(1 + exp(margin)) but more numerically stable.
MLUtils.log1pExp(margin)
} else {
math.log1p(math.exp(-minusYP)) + minusYP
MLUtils.log1pExp(margin) - margin
}

(gradient, loss)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

/**
Expand Down Expand Up @@ -61,13 +62,8 @@ object LogLoss extends Loss {
data.map { case point =>
val prediction = model.predict(point.features)
val margin = 2.0 * point.label * prediction
// The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
// stable.
if (margin >= 0) {
2.0 * math.log1p(math.exp(-margin))
} else {
2.0 * (-margin + math.log1p(math.exp(margin)))
}
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}.mean()
}
}
16 changes: 16 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,20 @@ object MLUtils {
}
sqDist
}

/**
 * Computes `math.log(1 + math.exp(x))` in a numerically stable way.
 *
 * Evaluating `math.log(1 + math.exp(x))` directly overflows to infinity as soon as
 * `x > 709.78`, since `math.exp(x)` exceeds `Double.MaxValue`. For positive inputs
 * we therefore use the algebraically identical form `x + math.log1p(math.exp(-x))`,
 * whose exponential argument is never positive and so cannot overflow.
 *
 * @param x a floating-point value as input.
 * @return the result of `math.log(1 + math.exp(x))`.
 */
private[mllib] def log1pExp(x: Double): Double =
  if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@ package org.apache.spark.mllib.util
import java.io.File

import scala.io.Source
import scala.math

import org.scalatest.FunSuite

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
squaredDistance => breezeSquaredDistance}
import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils

class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
Expand Down Expand Up @@ -204,4 +203,12 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
assert(points.collect().toSet === loaded.collect().toSet)
Utils.deleteRecursively(tempDir)
}

test("log1pExp") {
// Moderately large positive input: still representable by the naive formula,
// so we can compare against math.log1p(math.exp(x)) directly with a relative
// tolerance.
assert(log1pExp(76.3) ~== math.log1p(math.exp(76.3)) relTol 1E-10)
// Very large positive input: the naive log(1 + exp(x)) would overflow to
// Infinity; the stable form reduces to x because log1p(exp(-x)) underflows
// to 0, so the result should be x itself.
assert(log1pExp(87296763.234) ~== 87296763.234 relTol 1E-10)

// Negative inputs: exp(x) cannot overflow, so the naive formula is already
// stable and serves as the reference value. Absolute tolerance is used since
// the expected values approach 0.
assert(log1pExp(-13.8) ~== math.log1p(math.exp(-13.8)) absTol 1E-10)
assert(log1pExp(-238423789.865) ~== math.log1p(math.exp(-238423789.865)) absTol 1E-10)
}
}

0 comments on commit 64eefd0

Please sign in to comment.