remove modelParams
add a simple text classification pipeline
mengxr committed Nov 7, 2014
1 parent b95c408 commit 7772430
Showing 10 changed files with 228 additions and 79 deletions.
5 changes: 0 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -99,9 +99,4 @@ abstract class Estimator[M <: Model] extends PipelineStage with Params {
def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
fit(dataset.schemaRDD, paramMaps).asJava
}

/**
* Parameters for the output model.
*/
val modelParams: Params = Params.empty
}
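As an aside on the fit overload kept above: an estimator can be fit with an array of parameter maps and returns one fitted model per map. A minimal sketch, assuming the LogisticRegression estimator added later in this commit, a training SchemaRDD with the default "label"/"features" columns, and that a ParamMap can be built with new ParamMap().put(param, value) (the ParamMap API itself is not shown in this diff):

import org.apache.spark.ml.example.LogisticRegression  // path assumed
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// Sketch: fit the same estimator under two regularization settings;
// fit(dataset, paramMaps) returns one fitted model per ParamMap.
def fitTwoModels(training: SchemaRDD) = {
  val lr = new LogisticRegression()
  val paramMaps = Array(
    new ParamMap().put(lr.regParam, 0.1),   // put(...) assumed; regParam comes from HasRegParam
    new ParamMap().put(lr.regParam, 0.01))
  lr.fit(training, paramMaps)
}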
15 changes: 12 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -17,10 +17,19 @@

package org.apache.spark.ml

import org.apache.spark.ml.param.ParamMap

/**
* A trained model.
* A fitted model.
*/
abstract class Model extends Transformer {
// def parent: Estimator
// def trainingParameters: ParamMap
/**
* The parent estimator that produced this model.
*/
val parent: Estimator[_]

/**
* Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
*/
val fittingParamMap: ParamMap
}
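The new parent and fittingParamMap members mean a fitted model records exactly how it was produced. A minimal sketch of reading them back, assuming any Model instance (for example a LogisticRegressionModel from this commit):

import org.apache.spark.ml.Model

// Sketch: every fitted model now carries its provenance. Per the doc comment
// above, model.parent.fit(data, model.fittingParamMap) should reproduce an
// equivalent model.
def describe(model: Model): String =
  s"fitted by ${model.parent.getClass.getSimpleName} with ${model.fittingParamMap}"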
7 changes: 5 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -68,14 +68,17 @@ class Pipeline extends Estimator[PipelineModel] {
}
}

new PipelineModel(transformers.toArray)
new PipelineModel(this, map, transformers.toArray)
}
}

/**
* Represents a compiled pipeline.
*/
class PipelineModel(val transformers: Array[Transformer]) extends Model {
class PipelineModel(
override val parent: Pipeline,
override val fittingParamMap: ParamMap,
val transformers: Array[Transformer]) extends Model {

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
transformers.foldLeft(dataset) { (dataset, transformer) =>
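For reference, a fitted PipelineModel simply applies its transformers in sequence, as in the foldLeft above. A minimal sketch, assuming ParamMap has a no-arg constructor for an empty map (not shown in this diff); a full text-classification pipeline built from the new Tokenizer, HashingTF, and LogisticRegression appears after the Tokenizer file below:

import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// Sketch: apply a fitted pipeline with an empty (default) parameter map.
def applyPipeline(model: PipelineModel, dataset: SchemaRDD): SchemaRDD =
  model.transform(dataset, new ParamMap())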
@@ -25,30 +25,39 @@ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SchemaRDD

class CrossValidator extends Estimator[CrossValidatorModel] with Params with Logging {

private val f2jBLAS = new F2jBLAS

/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
trait CrossValidatorParams extends Params {
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
def setEstimator(value: Estimator[_]): this.type = { set(estimator, value); this }
def getEstimator: Estimator[_] = get(estimator)

val estimatorParamMaps: Param[Array[ParamMap]] =
new Param(this, "estimatorParamMaps", "param maps for the estimator")
def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps)
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = {
set(estimatorParamMaps, value)
this
}

val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
def setEvaluator(value: Evaluator): this.type = { set(evaluator, value); this }
def getEvaluator: Evaluator = get(evaluator)

val numFolds: IntParam =
new IntParam(this, "numFolds", "number of folds for cross validation", Some(3))
def setNumFolds(value: Int): this.type = { set(numFolds, value); this }
def getNumFolds: Int = get(numFolds)
}

/**
* K-fold cross validation.
*/
class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging {

private val f2jBLAS = new F2jBLAS

def setEstimator(value: Estimator[_]): this.type = { set(estimator, value); this }
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = {
set(estimatorParamMaps, value)
this
}
def setEvaluator(value: Evaluator): this.type = { set(evaluator, value); this }
def setNumFolds(value: Int): this.type = { set(numFolds, value); this }

/**
* Fits a single model to the input data with provided parameter map.
@@ -74,7 +83,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with Params with Logging {
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model]]
var i = 0
while(i < numModels) {
while (i < numModels) {
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map)
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
metrics(i) += metric
@@ -86,12 +95,21 @@ class CrossValidator extends Estimator[CrossValidatorModel] with Params with Logging {
val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
logInfo("Best set of parameters:\n" + epm(bestIndex))
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model]
new CrossValidatorModel(bestModel, bestMetric / map(numFolds))
val cvModel = new CrossValidatorModel(this, map, bestModel, bestMetric / map(numFolds))
Params.copyValues(this, cvModel)
cvModel
}
}

class CrossValidatorModel(bestModel: Model, metric: Double) extends Model {
def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
/**
* Model from k-fold cross validation.
*/
class CrossValidatorModel private[ml] (
override val parent: CrossValidator,
override val fittingParamMap: ParamMap,
bestModel: Model,
metric: Double) extends Model with CrossValidatorParams {
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
bestModel.transform(dataset, paramMap)
}
}
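A hedged sketch of wiring the CrossValidatorParams setters together. Only the setters and the fit/transform signatures come from this diff; the import paths, the ParamMap construction, and passing the Evaluator in from outside (no concrete evaluator appears here) are assumptions:

import org.apache.spark.ml.Evaluator  // path assumed
import org.apache.spark.ml.example.{CrossValidator, LogisticRegression}  // paths assumed
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// Sketch: 3-fold cross validation over two regularization settings.
// The best setting is refit on the full dataset and wrapped in a CrossValidatorModel.
def selectModel(training: SchemaRDD, evaluator: Evaluator) = {
  val lr = new LogisticRegression()
  val grid = Array(
    new ParamMap().put(lr.regParam, 0.1),   // put(...) assumed
    new ParamMap().put(lr.regParam, 0.01))
  val cv = new CrossValidator()
    .setEstimator(lr)
    .setEstimatorParamMaps(grid)
    .setEvaluator(evaluator)
    .setNumFolds(3)
  cv.fit(training, new ParamMap())
}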
29 changes: 29 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/example/HashingTF.scala
@@ -0,0 +1,29 @@
package org.apache.spark.ml.example

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.dsl._

class HashingTF extends Transformer with HasInputCol with HasOutputCol {

def setInputCol(value: String) = { set(inputCol, value); this }
def setOutputCol(value: String) = { set(outputCol, value); this }

val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
def setNumFeatures(value: Int) = { set(numFeatures, value); this }
def getNumFeatures: Int = get(numFeatures)

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val hashingTF = new feature.HashingTF(map(numFeatures))
val t: Iterable[_] => Vector = (doc) => {
hashingTF.transform(doc)
}
dataset.select(Star(None), t.call(map(inputCol).attr) as map(outputCol))
}
}
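A minimal usage sketch for the new HashingTF transformer, assuming a SchemaRDD with a "words" column holding token sequences (for example the Tokenizer output below) and the empty-ParamMap construction noted earlier:

import org.apache.spark.ml.example.HashingTF
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// Sketch: hash the token sequences in "words" into 1000-dimensional
// term-frequency vectors appended as a "features" column.
def hashFeatures(tokenized: SchemaRDD): SchemaRDD = {
  val hashingTF = new HashingTF()
    .setInputCol("words")
    .setOutputCol("features")
    .setNumFeatures(1000)
  hashingTF.transform(tokenized, new ParamMap())
}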
@@ -27,22 +27,27 @@ import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.expressions.Row

/**
* Params for logistic regression.
*/
trait LogisticRegressionParams extends Params with HasRegParam with HasMaxIter with HasLabelCol
with HasThreshold with HasFeaturesCol with HasScoreCol with HasPredictionCol

/**
* Logistic regression.
*/
class LogisticRegression extends Estimator[LogisticRegressionModel]
with HasRegParam with HasMaxIter with HasLabelCol with HasFeaturesCol {
class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams {

setRegParam(0.1)
setMaxIter(100)

def setRegParam(value: Double): this.type = { set(regParam, value); this }
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
def setLabelCol(value: String): this.type = { set(labelCol, value); this }
def setThreshold(value: Double): this.type = { set(threshold, value); this }
def setFeaturesCol(value: String): this.type = { set(featuresCol, value); this }

override final val modelParams: LogisticRegressionModelParams =
new LogisticRegressionModelParams {}
def setScoreCol(value: String): this.type = { set(scoreCol, value); this }
def setPredictionCol(value: String): this.type = { set(predictionCol, value); this }

override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
import dataset.sqlContext._
@@ -55,13 +60,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
lr.optimizer
.setRegParam(map(regParam))
.setNumIterations(map(maxIter))
val lrm = new LogisticRegressionModel(lr.run(instances).weights)
val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights)
instances.unpersist()
// copy model params
Params.copyValues(modelParams, lrm)
if (!lrm.isSet(lrm.featuresCol) && map.contains(lrm.featuresCol)) {
lrm.setFeaturesCol(map(featuresCol))
}
Params.copyValues(this, lrm)
lrm
}

@@ -74,19 +76,17 @@
}
}

trait LogisticRegressionModelParams extends Params with HasThreshold with HasFeaturesCol
with HasScoreCol with HasPredictionCol {
class LogisticRegressionModel private[ml] (
override val parent: LogisticRegression,
override val fittingParamMap: ParamMap,
val weights: Vector) extends Model with LogisticRegressionParams {

setThreshold(0.5)

def setThreshold(value: Double): this.type = { set(threshold, value); this }
def setFeaturesCol(value: String): this.type = { set(featuresCol, value); this }
def setScoreCol(value: String): this.type = { set(scoreCol, value); this }
def setPredictionCol(value: String): this.type = { set(predictionCol, value); this }
}

class LogisticRegressionModel(
val weights: Vector)
extends Model with LogisticRegressionModelParams {

setThreshold(0.5)

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
import dataset.sqlContext._
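A hedged sketch of the LogisticRegression estimator in use. The setters and the fit/transform signatures come from this diff; the import path, the default "label"/"features" column names, and the exact output columns of transform (the scoreCol/predictionCol params suggest a score and a prediction column are appended) are assumptions:

import org.apache.spark.ml.example.LogisticRegression  // path assumed
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// Sketch: train with non-default regularization and iteration count, then score test data.
// Params.copyValues(this, lrm) above copies the estimator's params onto the fitted model.
def trainAndScore(training: SchemaRDD, test: SchemaRDD): SchemaRDD = {
  val lr = new LogisticRegression()
    .setMaxIter(50)
    .setRegParam(0.01)
  val model = lr.fit(training, new ParamMap())
  model.transform(test, new ParamMap())
}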
@@ -26,11 +26,15 @@ import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.expressions.Row

class StandardScaler extends Estimator[StandardScalerModel] with HasInputCol {
/**
* Params for [[StandardScaler]] and [[StandardScalerModel]].
*/
trait StandardScalerParams extends Params with HasInputCol with HasOutputCol

def setInputCol(value: String): this.type = { set(inputCol, value); this }
class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {

override val modelParams: StandardScalerModelParams = new StandardScalerModelParams {}
def setInputCol(value: String): this.type = { set(inputCol, value); this }
def setOutputCol(value: String): this.type = { set(outputCol, value); this }

override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
import dataset.sqlContext._
@@ -40,22 +44,19 @@ class StandardScaler extends Estimator[StandardScalerModel] with HasInputCol {
v
}
val scaler = new feature.StandardScaler().fit(input)
val model = new StandardScalerModel(scaler)
Params.copyValues(modelParams, model)
if (!model.isSet(model.inputCol)) {
model.setInputCol(map(inputCol))
}
val model = new StandardScalerModel(this, map, scaler)
Params.copyValues(this, model)
model
}
}

trait StandardScalerModelParams extends Params with HasInputCol with HasOutputCol {
class StandardScalerModel private[ml] (
override val parent: StandardScaler,
override val fittingParamMap: ParamMap,
scaler: feature.StandardScalerModel) extends Model with StandardScalerParams {

def setInputCol(value: String): this.type = { set(inputCol, value); this }
def setOutputCol(value: String): this.type = { set(outputCol, value); this }
}

class StandardScalerModel private[ml] (
scaler: feature.StandardScalerModel) extends Model with StandardScalerModelParams {

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
import dataset.sqlContext._
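A minimal usage sketch for StandardScaler, assuming a vector-valued input column (shown here as "features") and an output column name of the caller's choosing; only the setters, fit, and transform appear in this diff:

import org.apache.spark.ml.example.StandardScaler  // path assumed
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// Sketch: fit scaling statistics on the training set, then apply them to test data.
// The fitted StandardScalerModel shares the same inputCol/outputCol params.
def scale(training: SchemaRDD, test: SchemaRDD): SchemaRDD = {
  val scaler = new StandardScaler()
    .setInputCol("features")
    .setOutputCol("scaledFeatures")
  val scalerModel = scaler.fit(training, new ParamMap())
  scalerModel.transform(test, new ParamMap())
}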
42 changes: 42 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/example/Tokenizer.scala
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.example

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, ParamMap}
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.dsl._

/**
* A simple tokenizer that splits input string by white spaces.
*/
class Tokenizer extends Transformer with HasInputCol with HasOutputCol {

def setInputCol(value: String) = { set(inputCol, value); this }
def setOutputCol(value: String) = { set(outputCol, value); this }

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val split: String => Seq[String] = (text) => {
text.split("\\s").toSeq
}
dataset.select(Star(None), split.call(map(inputCol).attr) as map(outputCol))
}
}
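Putting the pieces together, a hedged sketch of the simple text classification pipeline named in the commit message: Tokenizer, HashingTF, and LogisticRegression chained in a Pipeline. How stages are attached to a Pipeline is not visible in the hunks above, so the setStages call is an assumption (it matches the API this prototype later settled on); the import paths and the "text"/"label" training columns are likewise assumptions:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.example.{HashingTF, LogisticRegression, Tokenizer}  // paths assumed
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// Sketch: raw "text" -> tokens -> hashed term-frequency vectors -> logistic regression.
def buildAndFit(training: SchemaRDD) = {
  val tokenizer = new Tokenizer()
    .setInputCol("text")
    .setOutputCol("words")
  val hashingTF = new HashingTF()
    .setInputCol("words")
    .setOutputCol("features")
    .setNumFeatures(1000)
  val lr = new LogisticRegression()
    .setMaxIter(10)
    .setRegParam(0.01)
  val pipeline = new Pipeline()
    .setStages(Array(tokenizer, hashingTF, lr))  // setStages assumed, not shown in this diff
  pipeline.fit(training, new ParamMap())  // PipelineModel holding the fitted transformers in order
}

The resulting PipelineModel can then be applied to unlabeled documents with transform, running all three stages in a single call.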