-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-2852][MLLIB] Separate model from IDF/StandardScaler algorithms #1814
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,38 +35,47 @@ import org.apache.spark.rdd.RDD | |
* @param withStd True by default. Scales the data to unit standard deviation. | ||
*/ | ||
@Experimental | ||
class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { | ||
class StandardScaler(withMean: Boolean, withStd: Boolean) { | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This class is only used for keeping the state of withMean, and withStd, is it possible to move those states to fit function by overloading, and make it as object? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current API is more consistent with others like |
||
def this() = this(false, true) | ||
|
||
require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.") | ||
|
||
private var mean: BV[Double] = _ | ||
private var factor: BV[Double] = _ | ||
|
||
/** | ||
* Computes the mean and variance and stores as a model to be used for later scaling. | ||
* | ||
* @param data The data used to compute the mean and variance to build the transformation model. | ||
* @return This StandardScalar object. | ||
* @return a StandardScalarModel | ||
*/ | ||
def fit(data: RDD[Vector]): this.type = { | ||
def fit(data: RDD[Vector]): StandardScalerModel = { | ||
val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( | ||
(aggregator, data) => aggregator.add(data), | ||
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)) | ||
|
||
mean = summary.mean.toBreeze | ||
factor = summary.variance.toBreeze | ||
require(mean.length == factor.length) | ||
val mean = summary.mean.toBreeze | ||
val factor = summary.variance.toBreeze | ||
require(mean.size == factor.size) | ||
|
||
var i = 0 | ||
while (i < factor.length) { | ||
while (i < factor.size) { | ||
factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0 | ||
i += 1 | ||
} | ||
|
||
this | ||
new StandardScalerModel(withMean, withStd, mean, factor) | ||
} | ||
} | ||
|
||
/** | ||
* :: Experimental :: | ||
* Represents a StandardScaler model that can transform vectors. | ||
*/ | ||
@Experimental | ||
class StandardScalerModel private[mllib] ( | ||
val withMean: Boolean, | ||
val withStd: Boolean, | ||
val mean: BV[Double], | ||
val factor: BV[Double]) extends VectorTransformer { | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since users may want to know the variance of the training set, should we have constructor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done. |
||
/** | ||
* Applies standardization transformation on a vector. | ||
|
@@ -81,7 +90,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor | |
"Haven't learned column summary statistics yet. Call fit first.") | ||
} | ||
|
||
require(vector.size == mean.length) | ||
require(vector.size == mean.size) | ||
|
||
if (withMean) { | ||
vector.toBreeze match { | ||
|
@@ -115,5 +124,4 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor | |
vector | ||
} | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The following exception is used for unsupported vector in appendBias and StandardScaler, maybe we could have a global definition of this in util.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might want to use different error messages. In that case, having a util function doesn't save us much.