diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 07d395e78ee12..542977a48f0ac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1457,9 +1457,9 @@ class LogisticRegressionSuite */ val coefficientsR = new DenseMatrix(3, 2, Array( - 0.1881871, -0.0, + 0.1881871, 0.0, -0.02412645, 0.0, - -0.1640607, -0.0), isTransposed = true) + -0.1640607, 0.0), isTransposed = true) val interceptsR = Vectors.dense(0.2658824, 0.53604701, -0.8019294) model.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala index bcbcb0e8efab9..16ef4af4f94e8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala @@ -28,6 +28,7 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var instances: Array[Instance] = _ @transient var instancesConstantFeature: Array[Instance] = _ + @transient var instancesConstantFeatureFiltered: Array[Instance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -41,6 +42,11 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)), Instance(2.0, 0.3, Vectors.dense(1.0, 0.5)) ) + instancesConstantFeatureFiltered = Array( + Instance(0.0, 0.1, Vectors.dense(2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0)), + Instance(2.0, 0.3, Vectors.dense(0.5)) + ) } /** Get summary statistics for some data and create a new LogisticAggregator. */ @@ -233,30 +239,44 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { val binaryInstances = instancesConstantFeature.map { instance => if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features) } + val binaryInstancesFiltered = instancesConstantFeatureFiltered.map { instance => + if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features) + } val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0) + val coefArrayFiltered = Array(3.0, 0.0, -1.0) val interceptArray = Array(4.0, 2.0, -3.0) val aggConstantFeature = getNewAggregator(instancesConstantFeature, Vectors.dense(coefArray ++ interceptArray), fitIntercept = true, isMultinomial = true) - instances.foreach(aggConstantFeature.add) + val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered, + Vectors.dense(coefArrayFiltered ++ interceptArray), fitIntercept = true, isMultinomial = true) + + instancesConstantFeature.foreach(aggConstantFeature.add) + instancesConstantFeatureFiltered.foreach(aggConstantFeatureFiltered.add) // constant features should not affect gradient - def validateGradient(grad: Vector): Unit = { - assert(grad(0) === 0.0) - grad.toArray.foreach { gradientValue => - assert(!gradientValue.isNaN && - gradientValue > Double.NegativeInfinity && gradientValue < Double.PositiveInfinity) + def validateGradient(grad: Vector, gradFiltered: Vector, numCoefficientSets: Int): Unit = { + for (i <- 0 until numCoefficientSets) { + assert(grad(i) === 0.0) + assert(grad(numCoefficientSets + i) == gradFiltered(i)) } } - validateGradient(aggConstantFeature.gradient) + validateGradient(aggConstantFeature.gradient, aggConstantFeatureFiltered.gradient, 3) val binaryCoefArray = Array(1.0, 2.0) + val binaryCoefArrayFiltered = Array(2.0) val intercept = 1.0 val aggConstantFeatureBinary = getNewAggregator(binaryInstances, Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true, isMultinomial = false) - instances.foreach(aggConstantFeatureBinary.add) + val aggConstantFeatureBinaryFiltered = getNewAggregator(binaryInstancesFiltered, + Vectors.dense(binaryCoefArrayFiltered ++ Array(intercept)), fitIntercept = true, + isMultinomial = false) + binaryInstances.foreach(aggConstantFeatureBinary.add) + binaryInstancesFiltered.foreach(aggConstantFeatureBinaryFiltered.add) + // constant features should not affect gradient - validateGradient(aggConstantFeatureBinary.gradient) + validateGradient(aggConstantFeatureBinary.gradient, + aggConstantFeatureBinaryFiltered.gradient, 1) } }