style checker
sethah committed Jun 14, 2017
1 parent a5b18c2 commit 6edd128
Showing 6 changed files with 12 additions and 13 deletions.
@@ -20,16 +20,16 @@ package org.apache.spark.ml.classification
import java.util.Locale

import scala.collection.mutable

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.linalg.BLAS._
import org.apache.spark.ml.optim.aggregator.LogisticAggregator
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param._
@@ -17,6 +17,7 @@
package org.apache.spark.ml.optim.loss

import breeze.optimize.DiffFunction

import org.apache.spark.ml.linalg._

/**
@@ -39,11 +40,11 @@ private[ml] trait DifferentiableRegularization[T] extends DiffFunction[T] {
* @param regParam The magnitude of the regularization.
* @param shouldApply A function (Int => Boolean) indicating whether a given index should have
* regularization applied to it.
* @param featuresStd Option indicating whether the regularization should be scaled by the standard
* deviation of the features.
* @param featuresStd Option for a function which maps coefficient index (column major) to the
* feature standard deviation. If `None`, no standardization is applied.
*/
private[ml] class L2Regularization(
val regParam: Double,
override val regParam: Double,
shouldApply: Int => Boolean,
featuresStd: Option[Int => Double]) extends DifferentiableRegularization[Vector] {
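For context, a minimal sketch (not part of the commit) of how a caller could wire up the updated constructor; it mirrors the LinearRegression hunk below, but the std values and the 0.1 regParam are invented, and since the class is private[ml] such code has to live under the org.apache.spark.ml package:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.optim.loss.L2Regularization

// Hypothetical per-feature standard deviations, e.g. from a fitted summarizer.
val featuresStd = Array(0.5, 2.0, 1.0)
val numFeatures = featuresStd.length

// Regularize every coefficient index; the optional mapping supplies each feature's std.
val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
val getFeaturesStd = (j: Int) => if (j >= 0 && j < numFeatures) featuresStd(j) else 0.0

val l2Reg = new L2Regularization(0.1, shouldApply, Some(getFeaturesStd))

// DiffFunction.calculate returns the (loss, gradient) pair for the given coefficients.
val (regLoss, regGradient) = l2Reg.calculate(Vectors.dense(1.0, -0.5, 2.0))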

@@ -322,7 +322,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String

val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept),
bcFeaturesStd, bcFeaturesMean)(_)
val getFeaturesStd = (j: Int) => featuresStd(j)
val getFeaturesStd = (j: Int) => if (j >= 0 && j < numFeatures) featuresStd(j) else 0.0
val regularization = if (effectiveL2RegParam != 0.0) {
val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
Some(new L2Regularization(effectiveL2RegParam, shouldApply,
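As a side note (not from the commit itself), the effect of the added bounds check can be seen in isolation; the array below is made up:

// Hypothetical standard deviations for three features.
val featuresStd = Array(0.5, 2.0, 1.0)
val numFeatures = featuresStd.length

val oldGetFeaturesStd = (j: Int) => featuresStd(j)
val newGetFeaturesStd = (j: Int) => if (j >= 0 && j < numFeatures) featuresStd(j) else 0.0

newGetFeaturesStd(1)            // 2.0
newGetFeaturesStd(numFeatures)  // 0.0: out-of-range indices now fall back to zero
// oldGetFeaturesStd(numFeatures) would throw ArrayIndexOutOfBoundsException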
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
import scala.language.existentials
import scala.util.Random
import scala.util.control.Breaks._

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
@@ -21,8 +21,8 @@ import org.apache.spark.ml.classification.MultiClassSummarizer
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

class DifferentiableLossAggregatorSuite extends SparkFunSuite {

@@ -162,7 +162,7 @@ object DifferentiableLossAggregatorSuite {
}

/** Get feature and label summarizers for provided data. */
def getRegressionSummarizers(
private[ml] def getRegressionSummarizers(
instances: Array[Instance]): (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
instance: Instance) =>
@@ -179,9 +179,8 @@
}

/** Get feature and label summarizers for provided data. */
def getClassificationSummarizers(
instances: Array[Instance]):
(MultivariateOnlineSummarizer, MultiClassSummarizer) = {
private[ml] def getClassificationSummarizers(
instances: Array[Instance]): (MultivariateOnlineSummarizer, MultiClassSummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
instance: Instance) =>
(c._1.add(instance.features, instance.weight),
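A short usage sketch (not part of the commit) for these test helpers; the instances are made up, the accessor names are assumed from the summarizer classes imported above, and after this change the calling code must itself sit inside the org.apache.spark.ml namespace (here, the same package as the suite object):

import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vectors

// Hypothetical weighted instances: Instance(label, weight, features).
val instances = Array(
  Instance(1.0, 1.0, Vectors.dense(1.0, 2.0)),
  Instance(0.0, 2.0, Vectors.dense(3.0, 0.5)))

// The first element of each pair summarizes the features, the second the labels/classes.
val (featureSummarizer, labelSummarizer) =
  DifferentiableLossAggregatorSuite.getRegressionSummarizers(instances)
val featuresStd = featureSummarizer.variance.toArray.map(math.sqrt)

val (featureSummarizer2, classSummarizer) =
  DifferentiableLossAggregatorSuite.getClassificationSummarizers(instances)
val classCounts = classSummarizer.histogram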
@@ -98,7 +98,6 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
instances.foreach(aggIntercept.add)
instances.foreach(aggNoIntercept.add)

// least squares agg does not include intercept in its gradient array
assert(aggIntercept.gradient.size === (numFeatures + 1) * numClasses)
assert(aggNoIntercept.gradient.size === numFeatures * numClasses)
}
@@ -116,7 +115,6 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
binaryInstances.foreach(aggIntercept.add)
binaryInstances.foreach(aggNoIntercept.add)

// least squares agg does not include intercept in its gradient array
assert(aggIntercept.gradient.size === numFeatures + 1)
assert(aggNoIntercept.gradient.size === numFeatures)
}
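The assertions in these two tests encode simple dimension bookkeeping; a quick worked instance with invented sizes, not taken from the suite:

// With an intercept, each class gets one extra slot in the gradient array.
val numFeatures = 2
val numClasses = 3
val multinomialWithIntercept = (numFeatures + 1) * numClasses // 9
val multinomialNoIntercept = numFeatures * numClasses         // 6
val binaryWithIntercept = numFeatures + 1                     // 3
val binaryNoIntercept = numFeatures                           // 2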
