diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index ed24e35425a59..6fb911f65f69f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -94,12 +94,24 @@ class LogisticGradient extends Gradient { weights: Vector, cumGradient: Vector): Double = { val margin = -1.0 * dot(data, weights) - val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label + /** + * gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label + * However, the first part of gradientMultiplier can be potentially suffered from overflow, + * so we use the equivalent formula but more numerically stable. + */ + val gradientMultiplier = + if (margin > 0.0) { + val temp = math.exp(-margin) + temp / (1.0 + temp) - label + } else { + 1.0 / (1.0 + math.exp(margin)) - label + } axpy(gradientMultiplier, data, cumGradient) if (label > 0) { - math.log1p(math.exp(margin)) + // The following is equivalent to log(1 + exp(margin)) but more numerically stable. + MLUtils.log1pExp(margin) } else { - math.log1p(math.exp(margin)) - margin + MLUtils.log1pExp(margin) - margin } } }