Skip to content

Commit

Permalink
update testcase
Browse files Browse the repository at this point in the history
  • Loading branch information
WeichenXu123 committed Aug 21, 2017
1 parent 0f28e5e commit 1f4ba14
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1457,9 +1457,9 @@ class LogisticRegressionSuite
*/

val coefficientsR = new DenseMatrix(3, 2, Array(
0.1881871, -0.0,
0.1881871, 0.0,
-0.02412645, 0.0,
-0.1640607, -0.0), isTransposed = true)
-0.1640607, 0.0), isTransposed = true)
val interceptsR = Vectors.dense(0.2658824, 0.53604701, -0.8019294)

model.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {

@transient var instances: Array[Instance] = _
@transient var instancesConstantFeature: Array[Instance] = _
@transient var instancesConstantFeatureFiltered: Array[Instance] = _

override def beforeAll(): Unit = {
super.beforeAll()
Expand All @@ -41,6 +42,11 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)),
Instance(2.0, 0.3, Vectors.dense(1.0, 0.5))
)
instancesConstantFeatureFiltered = Array(
Instance(0.0, 0.1, Vectors.dense(2.0)),
Instance(1.0, 0.5, Vectors.dense(1.0)),
Instance(2.0, 0.3, Vectors.dense(0.5))
)
}

/** Get summary statistics for some data and create a new LogisticAggregator. */
Expand Down Expand Up @@ -233,30 +239,44 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
val binaryInstances = instancesConstantFeature.map { instance =>
if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features)
}
val binaryInstancesFiltered = instancesConstantFeatureFiltered.map { instance =>
if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features)
}
val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0)
val coefArrayFiltered = Array(3.0, 0.0, -1.0)
val interceptArray = Array(4.0, 2.0, -3.0)
val aggConstantFeature = getNewAggregator(instancesConstantFeature,
Vectors.dense(coefArray ++ interceptArray), fitIntercept = true, isMultinomial = true)
instances.foreach(aggConstantFeature.add)
val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered,
Vectors.dense(coefArrayFiltered ++ interceptArray), fitIntercept = true, isMultinomial = true)

instancesConstantFeature.foreach(aggConstantFeature.add)
instancesConstantFeatureFiltered.foreach(aggConstantFeatureFiltered.add)

// constant features should not affect gradient
def validateGradient(grad: Vector): Unit = {
assert(grad(0) === 0.0)
grad.toArray.foreach { gradientValue =>
assert(!gradientValue.isNaN &&
gradientValue > Double.NegativeInfinity && gradientValue < Double.PositiveInfinity)
def validateGradient(grad: Vector, gradFiltered: Vector, numCoefficientSets: Int): Unit = {
for (i <- 0 until numCoefficientSets) {
assert(grad(i) === 0.0)
assert(grad(numCoefficientSets + i) == gradFiltered(i))
}
}

validateGradient(aggConstantFeature.gradient)
validateGradient(aggConstantFeature.gradient, aggConstantFeatureFiltered.gradient, 3)

val binaryCoefArray = Array(1.0, 2.0)
val binaryCoefArrayFiltered = Array(2.0)
val intercept = 1.0
val aggConstantFeatureBinary = getNewAggregator(binaryInstances,
Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true,
isMultinomial = false)
instances.foreach(aggConstantFeatureBinary.add)
val aggConstantFeatureBinaryFiltered = getNewAggregator(binaryInstancesFiltered,
Vectors.dense(binaryCoefArrayFiltered ++ Array(intercept)), fitIntercept = true,
isMultinomial = false)
binaryInstances.foreach(aggConstantFeatureBinary.add)
binaryInstancesFiltered.foreach(aggConstantFeatureBinaryFiltered.add)

// constant features should not affect gradient
validateGradient(aggConstantFeatureBinary.gradient)
validateGradient(aggConstantFeatureBinary.gradient,
aggConstantFeatureBinaryFiltered.gradient, 1)
}
}

0 comments on commit 1f4ba14

Please sign in to comment.