Skip to content

Commit

Permalink
[SPARK-5128][MLLib] Add common used log1pExp API in MLUtils
Browse files Browse the repository at this point in the history
When `x` is positive and large, computing `math.log(1 + math.exp(x))` will lead to arithmetic
overflow. This will happen when `x > 709.78` which is not a very large number.
It can be addressed by rewriting the formula into `x + math.log1p(math.exp(-x))` when `x > 0`.

Author: DB Tsai <dbtsai@alpinenow.com>

Closes apache#3915 from dbtsai/mathutil and squashes the following commits:

bec6a84 [DB Tsai] remove empty line
3239541 [DB Tsai] revert part of patch into another PR
23144f3 [DB Tsai] doc
49f3658 [DB Tsai] temp
6c29ed3 [DB Tsai] formating
f8447f9 [DB Tsai] address another overflow issue in gradientMultiplier in LOR gradient code
64eefd0 [DB Tsai] first commit
  • Loading branch information
DB Tsai authored and mengxr committed Jan 7, 2015
1 parent 6e74ede commit 60e2d9e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.optimization
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
import org.apache.spark.mllib.util.MLUtils

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -64,17 +65,12 @@ class LogisticGradient extends Gradient {
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
val gradient = data.copy
scal(gradientMultiplier, gradient)
val minusYP = if (label > 0) margin else -margin

// log1p is log(1+p) but more accurate for small p
// Following two equations are the same analytically but not numerically, e.g.,
// math.log1p(math.exp(1000)) == Infinity
// 1000 + math.log1p(math.exp(-1000)) == 1000.0
val loss =
if (minusYP < 0) {
math.log1p(math.exp(minusYP))
if (label > 0) {
// The following is equivalent to log(1 + exp(margin)) but more numerically stable.
MLUtils.log1pExp(margin)
} else {
math.log1p(math.exp(-minusYP)) + minusYP
MLUtils.log1pExp(margin) - margin
}

(gradient, loss)
Expand All @@ -89,9 +85,10 @@ class LogisticGradient extends Gradient {
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
axpy(gradientMultiplier, data, cumGradient)
if (label > 0) {
math.log1p(math.exp(margin))
// The following is equivalent to log(1 + exp(margin)) but more numerically stable.
MLUtils.log1pExp(margin)
} else {
math.log1p(math.exp(margin)) - margin
MLUtils.log1pExp(margin) - margin
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

/**
Expand Down Expand Up @@ -61,13 +62,8 @@ object LogLoss extends Loss {
data.map { case point =>
val prediction = model.predict(point.features)
val margin = 2.0 * point.label * prediction
// The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
// stable.
if (margin >= 0) {
2.0 * math.log1p(math.exp(-margin))
} else {
2.0 * (-margin + math.log1p(math.exp(margin)))
}
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}.mean()
}
}
16 changes: 16 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,20 @@ object MLUtils {
}
sqDist
}

/**
 * Computes `math.log(1 + math.exp(x))` in a numerically stable way.
 *
 * Evaluating the naive formula directly overflows once `x > 709.78`, because
 * `math.exp(x)` becomes `Double.PositiveInfinity`. The identity
 * `log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|))` keeps the argument of
 * `math.exp` non-positive, so the exponential never overflows, while the
 * result is identical to the two-branch formulation
 * (`x + log1p(exp(-x))` for positive `x`, `log1p(exp(x))` otherwise).
 *
 * @param x a floating-point value as input.
 * @return the result of `math.log(1 + math.exp(x))`.
 */
private[mllib] def log1pExp(x: Double): Double =
  math.max(x, 0.0) + math.log1p(math.exp(-math.abs(x)))
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@ package org.apache.spark.mllib.util
import java.io.File

import scala.io.Source
import scala.math

import org.scalatest.FunSuite

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
squaredDistance => breezeSquaredDistance}
import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils

class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
Expand Down Expand Up @@ -204,4 +203,12 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
assert(points.collect().toSet === loaded.collect().toSet)
Utils.deleteRecursively(tempDir)
}

test("log1pExp") {
  // Moderate positive input: the naive log1p(exp(x)) is still finite here,
  // so log1pExp must agree with it to within relative tolerance.
  assert(log1pExp(76.3) ~== math.log1p(math.exp(76.3)) relTol 1E-10)
  // Very large positive input: the naive formula would overflow
  // (math.exp overflows past ~709.78), but log1pExp(x) ~= x for large x.
  assert(log1pExp(87296763.234) ~== 87296763.234 relTol 1E-10)

  // Non-positive inputs take the log1p(exp(x)) branch directly; use absolute
  // tolerance since the expected values are near zero.
  assert(log1pExp(-13.8) ~== math.log1p(math.exp(-13.8)) absTol 1E-10)
  assert(log1pExp(-238423789.865) ~== math.log1p(math.exp(-238423789.865)) absTol 1E-10)
}
}

0 comments on commit 60e2d9e

Please sign in to comment.