Skip to content

Commit

Permalink
update logisticAggregatorSuite
Browse files Browse the repository at this point in the history
  • Loading branch information
WeichenXu123 committed Aug 21, 2017
1 parent 8515b20 commit 0f28e5e
Showing 1 changed file with 11 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,17 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
val aggConstantFeature = getNewAggregator(instancesConstantFeature,
Vectors.dense(coefArray ++ interceptArray), fitIntercept = true, isMultinomial = true)
instances.foreach(aggConstantFeature.add)

// constant features should not affect gradient
assert(aggConstantFeature.gradient(0) === 0.0)
def validateGradient(grad: Vector): Unit = {
assert(grad(0) === 0.0)
grad.toArray.foreach { gradientValue =>
assert(!gradientValue.isNaN &&
gradientValue > Double.NegativeInfinity && gradientValue < Double.PositiveInfinity)
}
}

validateGradient(aggConstantFeature.gradient)

val binaryCoefArray = Array(1.0, 2.0)
val intercept = 1.0
Expand All @@ -248,6 +257,6 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
isMultinomial = false)
instances.foreach(aggConstantFeatureBinary.add)
// constant features should not affect gradient
assert(aggConstantFeatureBinary.gradient(0) === 0.0)
validateGradient(aggConstantFeatureBinary.gradient)
}
}

0 comments on commit 0f28e5e

Please sign in to comment.