
Commit

update params
mengxr committed Nov 7, 2014
1 parent fe0ee92 commit bab3e5b
Showing 6 changed files with 183 additions and 147 deletions.
34 changes: 27 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -25,29 +25,28 @@ import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.api.java.JavaSchemaRDD

/**
* Abstract class for estimators that fits models to data.
* Abstract class for estimators that fit models to data.
*/
abstract class Estimator[M <: Model] extends PipelineStage with Params {

/**
* Fits a single model to the input data with optional parameters.
*
* @param dataset input dataset
* @param paramPairs optional list of param pairs, overwrite embedded params
* @param paramPairs optional list of param pairs (overwrite embedded params)
* @return fitted model
*/
@varargs
def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
val map = new ParamMap()
paramPairs.foreach(map.put(_))
val map = new ParamMap().put(paramPairs: _*)
fit(dataset, map)
}

/**
* Fits a single model to the input data with provided parameter map.
*
* @param dataset input dataset
* @param paramMap parameters
* @param paramMap parameter map
* @return fitted model
*/
def fit(dataset: SchemaRDD, paramMap: ParamMap): M
@@ -61,27 +60,48 @@ abstract class Estimator[M <: Model] extends PipelineStage with Params {
* @param paramMaps an array of parameter maps
* @return fitted models, matching the input parameter maps
*/
def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { // how to return an array?
def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
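
For orientation, a hedged sketch of how these three overloads might be called, using the LogisticRegression estimator from later in this commit; dataset stands for a SchemaRDD already in scope, and the maxIter param and the exact ParamPair construction syntax are assumptions, since neither is defined in this hunk:

val lr = new LogisticRegression()

// (1) Inline param pairs override params embedded in the estimator.
val m1 = lr.fit(dataset, ParamPair(lr.maxIter, 50))   // assumed ParamPair(param, value) constructor

// (2) A prebuilt ParamMap can be reused across calls.
val map = new ParamMap().put(ParamPair(lr.maxIter, 50))
val m2 = lr.fit(dataset, map)

// (3) An array of maps fits one model per map, e.g. for a parameter grid.
val grid = Array(map, new ParamMap().put(ParamPair(lr.maxIter, 100)))
val models: Seq[LogisticRegressionModel] = lr.fit(dataset, grid)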

// Java-friendly versions of fit.

/**
* Fits a single model to the input data with optional parameters.
*
* @param dataset input dataset
* @param paramPairs optional list of param pairs (overwrite embedded params)
* @return fitted model
*/
@varargs
def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
fit(dataset.schemaRDD, paramPairs: _*)
}

/**
* Fits a single model to the input data with provided parameter map.
*
* @param dataset input dataset
* @param paramMap parameter map
* @return fitted model
*/
def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
fit(dataset.schemaRDD, paramMap)
}

/**
* Fits multiple models to the input data with multiple sets of parameters.
*
* @param dataset input dataset
* @param paramMaps an array of parameter maps
* @return fitted models, matching the input parameter maps
*/
def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
fit(dataset.schemaRDD, paramMaps).asJava
}
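
One detail worth noting: the asJava call in the last wrapper comes from Scala's Java collection converters, whose import sits outside the lines shown here. A minimal, self-contained sketch of that conversion in isolation:

import scala.collection.JavaConverters._

// Seq[M] -> java.util.List[M], the same conversion the wrapper above relies on.
def toJavaList[T](models: Seq[T]): java.util.List[T] = models.asJava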

/**
* Parameters for the output model.
*/
def modelParams: Params = Params.empty
val modelParams: Params = Params.empty
}
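
A hedged note on the def-to-val change just above: the concrete estimators in this commit copy modelParams onto the fitted model, so it needs to be one stable object that remembers what callers set on it; a val is evaluated once, while a def could hand out a fresh object on every access and silently drop those settings. A minimal stand-in sketch of the difference (MutableBox is hypothetical and not the real Params API):

class MutableBox { var value: Int = 0 }

class WithDef { def box: MutableBox = new MutableBox }   // new object on every access
class WithVal { val box: MutableBox = new MutableBox }   // one stable object

val d = new WithDef; d.box.value = 5
assert(d.box.value == 0)   // the write went to a throwaway instance

val v = new WithVal; v.box.value = 5
assert(v.box.value == 5)   // the write sticks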
@@ -20,7 +20,7 @@ package org.apache.spark.ml.example
import com.github.fommil.netlib.F2jBLAS

import org.apache.spark.ml._
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SchemaRDD

@@ -44,7 +44,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with Params {
def setEvaluator(value: Evaluator): this.type = { set(evaluator, value); this }
def getEvaluator: Evaluator = get(evaluator)

val numFolds: Param[Int] = new Param(this, "numFolds", "number of folds for cross validation", 3)
val numFolds: Param[Int] =
new IntParam(this, "numFolds", "number of folds for cross validation", Some(3))
def setNumFolds(value: Int): this.type = { set(numFolds, value); this }
def getNumFolds: Int = get(numFolds)
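
The numFolds change above swaps the untyped Param[Int] constructor for the typed IntParam, with the default now passed as an Option. A hedged sketch of the same declaration pattern for a hypothetical maxBins parameter, assuming IntParam keeps the (parent, name, doc, default) shape used here:

val maxBins: IntParam =
  new IntParam(this, "maxBins", "maximum number of bins", Some(32))
def setMaxBins(value: Int): this.type = { set(maxBins, value); this }
def getMaxBins: Int = get(maxBins)   // presumably falls back to the default (32) when unset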

@@ -58,8 +58,9 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
.setNumIterations(maxIter)
val lrm = new LogisticRegressionModel(lr.run(instances).weights)
instances.unpersist()
// copy model params
Params.copyValues(modelParams, lrm)
if (!lrm.paramMap.contains(lrm.featuresCol) && map.contains(lrm.featuresCol)) {
if (!lrm.isSet(lrm.featuresCol) && map.contains(lrm.featuresCol)) {
lrm.setFeaturesCol(featuresCol)
}
lrm
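
Two things change in this fit body: Params.copyValues now pushes the estimator's modelParams onto the new model (per the added comment), and the featuresCol fallback consults isSet on the model rather than reaching into its paramMap directly. The same isSet change appears in StandardScaler further down. A hedged restatement of the guard with explanatory comments, names generalized:

// After training, carry user-set model params over to the fitted model.
Params.copyValues(modelParams, model)
// Inherit the estimator's column name only when the caller has not set one
// explicitly on the model; isSet asks whether an explicit value exists.
if (!model.isSet(model.featuresCol) && map.contains(model.featuresCol)) {
  model.setFeaturesCol(featuresCol)
}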
@@ -69,8 +70,8 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
* Validates parameters specified by the input parameter map.
 * Raises an exception if any parameter belonging to this object is invalid.
*/
override def validateParams(paramMap: ParamMap): Unit = {
super.validateParams(paramMap)
override def validate(paramMap: ParamMap): Unit = {
super.validate(paramMap)
}
}
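
The override is renamed from validateParams to validate and still only delegates. A hedged sketch of what a subclass-specific check could look like, treating maxIter as a hypothetical Param[Int] and assuming ParamMap supports an apply-style lookup (only contains appears in this diff):

override def validate(paramMap: ParamMap): Unit = {
  super.validate(paramMap)
  // Hypothetical extra check: reject a non-positive iteration count if one is supplied.
  if (paramMap.contains(maxIter)) {
    require(paramMap(maxIter) > 0, "maxIter must be positive")
  }
}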

@@ -43,7 +43,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with HasInputCol {
val scaler = new feature.StandardScaler().fit(input)
val model = new StandardScalerModel(scaler)
Params.copyValues(modelParams, model)
if (!model.paramMap.contains(model.inputCol)) {
if (!model.isSet(model.inputCol)) {
model.setInputCol(inputCol)
}
model