diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 0c6a37bab0aad..9c131a41850cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.feature.ChiSqSelectorType +import org.apache.spark.mllib.feature.{ChiSqSelector => OldChiSqSelector} import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.rdd.RDD @@ -44,7 +44,9 @@ private[feature] trait ChiSqSelectorParams extends Params /** * Number of features that selector will select (ordered by statistic value descending). If the * number of features is less than numTopFeatures, then this will select all features. + * Only applicable when selectorType = "kbest". * The default value of numTopFeatures is 50. + * * @group param */ final val numTopFeatures = new IntParam(this, "numTopFeatures", @@ -56,6 +58,11 @@ private[feature] trait ChiSqSelectorParams extends Params /** @group getParam */ def getNumTopFeatures: Int = $(numTopFeatures) + /** + * Percentile of features that selector will select, ordered by statistics value descending. + * Only applicable when selectorType = "percentile". + * Default value is 0.1. + */ final val percentile = new DoubleParam(this, "percentile", "Percentile of features that selector will select, ordered by statistics value descending.", ParamValidators.inRange(0, 1)) @@ -64,8 +71,12 @@ private[feature] trait ChiSqSelectorParams extends Params /** @group getParam */ def getPercentile: Double = $(percentile) - final val alpha = new DoubleParam(this, "alpha", - "The highest p-value for features to be kept.", + /** + * The highest p-value for features to be kept. + * Only applicable when selectorType = "fpr". + * Default value is 0.05. + */ + final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.", ParamValidators.inRange(0, 1)) setDefault(alpha -> 0.05) @@ -73,29 +84,27 @@ private[feature] trait ChiSqSelectorParams extends Params def getAlpha: Double = $(alpha) /** - * The ChiSqSelector supports KBest, Percentile, FPR selection, - * which is the same as ChiSqSelectorType defined in MLLIB. - * when call setNumTopFeatures, the selectorType is set to KBest - * when call setPercentile, the selectorType is set to Percentile - * when call setAlpha, the selectorType is set to FPR + * The selector type of the ChisqSelector. + * Supported options: "kbest" (default), "percentile" and "fpr". */ final val selectorType = new Param[String](this, "selectorType", - "ChiSqSelector Type: KBest, Percentile, FPR") - setDefault(selectorType -> ChiSqSelectorType.KBest.toString) + "The selector type of the ChisqSelector. " + + "Supported options: kbest (default), percentile and fpr.", + ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray)) + setDefault(selectorType -> OldChiSqSelector.KBest) /** @group getParam */ - def getChiSqSelectorType: String = $(selectorType) + def getSelectorType: String = $(selectorType) } /** * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. - * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - * `KBest` chooses the `k` top features according to a chi-squared test. - * `Percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `FPR` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `KBest`, the default number of top features is 50. - * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. + * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. + * `kbest` chooses the `k` top features according to a chi-squared test. + * `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * `fpr` chooses all features whose false positive rate meets some threshold. + * By default, the selection method is `kbest`, the default number of top features is 50. */ @Since("1.6.0") final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) @@ -104,24 +113,21 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str @Since("1.6.0") def this() = this(Identifiable.randomUID("chiSqSelector")) + /** @group setParam */ + @Since("2.1.0") + def setSelectorType(value: String): this.type = set(selectorType, value) + /** @group setParam */ @Since("1.6.0") - def setNumTopFeatures(value: Int): this.type = { - set(selectorType, ChiSqSelectorType.KBest.toString) - set(numTopFeatures, value) - } + def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) + /** @group setParam */ @Since("2.1.0") - def setPercentile(value: Double): this.type = { - set(selectorType, ChiSqSelectorType.Percentile.toString) - set(percentile, value) - } + def setPercentile(value: Double): this.type = set(percentile, value) + /** @group setParam */ @Since("2.1.0") - def setAlpha(value: Double): this.type = { - set(selectorType, ChiSqSelectorType.FPR.toString) - set(alpha, value) - } + def setAlpha(value: Double): this.type = set(alpha, value) /** @group setParam */ @Since("1.6.0") @@ -143,23 +149,23 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str case Row(label: Double, features: Vector) => OldLabeledPoint(label, OldVectors.fromML(features)) } - var selector = new feature.ChiSqSelector() - ChiSqSelectorType.withName($(selectorType)) match { - case ChiSqSelectorType.KBest => - selector.setNumTopFeatures($(numTopFeatures)) - case ChiSqSelectorType.Percentile => - selector.setPercentile($(percentile)) - case ChiSqSelectorType.FPR => - selector.setAlpha($(alpha)) - case errorType => - throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") - } + val selector = new feature.ChiSqSelector() + .setSelectorType($(selectorType)) + .setNumTopFeatures($(numTopFeatures)) + .setPercentile($(percentile)) + .setAlpha($(alpha)) val model = selector.fit(input) copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { + val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType)) + otherPairs.foreach { case (_, paramName: String) => + if (isSet(getParam(paramName))) { + logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") + } + } SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.checkNumericType(schema, $(labelCol)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 5cffbf0892888..904000f50d0a2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -629,35 +629,23 @@ private[python] class PythonMLLibAPI extends Serializable { } /** - * Java stub for ChiSqSelector.fit() when the seletion type is KBest. This stub returns a + * Java stub for ChiSqSelector.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. */ - def fitChiSqSelectorKBest(numTopFeatures: Int, - data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { - new ChiSqSelector().setNumTopFeatures(numTopFeatures).fit(data.rdd) - } - - /** - * Java stub for ChiSqSelector.fit() when the selection type is Percentile. This stub returns a - * handle to the Java object instead of the content of the Java object. - * Extra care needs to be taken in the Python code to ensure it gets freed on - * exit; see the Py4J documentation. - */ - def fitChiSqSelectorPercentile(percentile: Double, - data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { - new ChiSqSelector().setPercentile(percentile).fit(data.rdd) - } - - /** - * Java stub for ChiSqSelector.fit() when the selection type is FPR. This stub returns a - * handle to the Java object instead of the content of the Java object. - * Extra care needs to be taken in the Python code to ensure it gets freed on - * exit; see the Py4J documentation. - */ - def fitChiSqSelectorFPR(alpha: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { - new ChiSqSelector().setAlpha(alpha).fit(data.rdd) + def fitChiSqSelector( + selectorType: String, + numTopFeatures: Int, + percentile: Double, + alpha: Double, + data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { + new ChiSqSelector() + .setSelectorType(selectorType) + .setNumTopFeatures(numTopFeatures) + .setPercentile(percentile) + .setAlpha(alpha) + .fit(data.rdd) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f68a017184b21..0f7c6e8bc04bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -32,12 +32,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} -@Since("2.1.0") -private[spark] object ChiSqSelectorType extends Enumeration { - type SelectorType = Value - val KBest, Percentile, FPR = Value -} - /** * Chi Squared selector model. * @@ -166,19 +160,18 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { /** * Creates a ChiSquared feature selector. - * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - * `KBest` chooses the `k` top features according to a chi-squared test. - * `Percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `FPR` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `KBest`, the default number of top features is 50. - * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. + * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. + * `kbest` chooses the `k` top features according to a chi-squared test. + * `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * `fpr` chooses all features whose false positive rate meets some threshold. + * By default, the selection method is `kbest`, the default number of top features is 50. */ @Since("1.3.0") class ChiSqSelector @Since("2.1.0") () extends Serializable { var numTopFeatures: Int = 50 var percentile: Double = 0.1 var alpha: Double = 0.05 - var selectorType = ChiSqSelectorType.KBest + var selectorType = ChiSqSelector.KBest /** * The is the same to call this() and setNumTopFeatures(numTopFeatures) @@ -192,7 +185,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { @Since("1.6.0") def setNumTopFeatures(value: Int): this.type = { numTopFeatures = value - selectorType = ChiSqSelectorType.KBest this } @@ -200,7 +192,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def setPercentile(value: Double): this.type = { require(0.0 <= value && value <= 1.0, "Percentile must be in [0,1]") percentile = value - selectorType = ChiSqSelectorType.Percentile this } @@ -208,12 +199,13 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def setAlpha(value: Double): this.type = { require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]") alpha = value - selectorType = ChiSqSelectorType.FPR this } @Since("2.1.0") - def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = { + def setSelectorType(value: String): this.type = { + require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value), + s"ChiSqSelector Type: $value was not supported.") selectorType = value this } @@ -230,11 +222,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { val chiSqTestResult = Statistics.chiSqTest(data) .zipWithIndex.sortBy { case (res, _) => -res.statistic } val features = selectorType match { - case ChiSqSelectorType.KBest => chiSqTestResult + case ChiSqSelector.KBest => chiSqTestResult .take(numTopFeatures) - case ChiSqSelectorType.Percentile => chiSqTestResult + case ChiSqSelector.Percentile => chiSqTestResult .take((chiSqTestResult.length * percentile).toInt) - case ChiSqSelectorType.FPR => chiSqTestResult + case ChiSqSelector.FPR => chiSqTestResult .filter{ case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") @@ -244,3 +236,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } } +@Since("2.1.0") +object ChiSqSelector { + + /** String name for `kbest` selector type. */ + private[spark] val KBest: String = "kbest" + + /** String name for `percentile` selector type. */ + private[spark] val Percentile: String = "percentile" + + /** String name for `fpr` selector type. */ + private[spark] val FPR: String = "fpr" + + /** Set of selector type and param pairs that ChiSqSelector supports. */ + private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures", + Percentile -> "percentile", FPR -> "alpha") + + /** Set of selector types that ChiSqSelector supports. */ + private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index e0293dbc4b0b2..6b56e4200250c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -50,6 +50,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext .toDF("label", "data", "preFilteredData") val selector = new ChiSqSelector() + .setSelectorType("kbest") .setNumTopFeatures(1) .setFeaturesCol("data") .setLabelCol("label") @@ -60,12 +61,28 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(vec1 ~== vec2 absTol 1e-1) } - selector.setPercentile(0.34).fit(df).transform(df) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df) + .select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + + val preFilteredData2 = Seq( + Vectors.dense(8.0, 7.0), + Vectors.dense(0.0, 9.0), + Vectors.dense(0.0, 9.0), + Vectors.dense(8.0, 9.0) + ) + val df2 = sc.parallelize(data.zip(preFilteredData2)) + .map(x => (x._1.label, x._1.features, x._2)) + .toDF("label", "data", "preFilteredData") + + selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2) + .select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } } test("ChiSqSelector read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index e181a544f7159..ec23a4aa7364d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -76,7 +76,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(2.0, Vectors.dense(Array(9.0)))) - val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData) + val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) }.collect().toSet diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c45434f1a57ca..12a13849dc9bc 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2586,39 +2586,68 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja .. versionadded:: 2.0.0 """ + selectorType = Param(Params._dummy(), "selectorType", + "The selector type of the ChisqSelector. " + + "Supported options: kbest (default), percentile and fpr.", + typeConverter=TypeConverters.toString) + numTopFeatures = \ Param(Params._dummy(), "numTopFeatures", "Number of features that selector will select, ordered by statistics value " + "descending. If the number of features is < numTopFeatures, then this will select " + "all features.", typeConverter=TypeConverters.toInt) + percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " + + "will select, ordered by statistics value descending.", + typeConverter=TypeConverters.toFloat) + + alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.", + typeConverter=TypeConverters.toFloat) + @keyword_only - def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"): + def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05): """ - __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label") + __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05) """ super(ChiSqSelector, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) - self._setDefault(numTopFeatures=50) + self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("2.0.0") def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, - labelCol="labels"): + labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05): """ - setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\ - labelCol="labels") + setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05) Sets params for this ChiSqSelector. """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("2.1.0") + def setSelectorType(self, value): + """ + Sets the value of :py:attr:`selectorType`. + """ + return self._set(selectorType=value) + + @since("2.1.0") + def getSelectorType(self): + """ + Gets the value of selectorType or its default value. + """ + return self.getOrDefault(self.selectorType) + @since("2.0.0") def setNumTopFeatures(self, value): """ Sets the value of :py:attr:`numTopFeatures`. + Only applicable when selectorType = "kbest". """ return self._set(numTopFeatures=value) @@ -2629,6 +2658,36 @@ def getNumTopFeatures(self): """ return self.getOrDefault(self.numTopFeatures) + @since("2.1.0") + def setPercentile(self, value): + """ + Sets the value of :py:attr:`percentile`. + Only applicable when selectorType = "percentile". + """ + return self._set(percentile=value) + + @since("2.1.0") + def getPercentile(self): + """ + Gets the value of percentile or its default value. + """ + return self.getOrDefault(self.percentile) + + @since("2.1.0") + def setAlpha(self, value): + """ + Sets the value of :py:attr:`alpha`. + Only applicable when selectorType = "fpr". + """ + return self._set(alpha=value) + + @since("2.1.0") + def getAlpha(self): + """ + Gets the value of alpha or its default value. + """ + return self.getOrDefault(self.alpha) + def _create_model(self, java_model): return ChiSqSelectorModel(java_model) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 077c11370eb3f..4aea81840a162 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -271,22 +271,14 @@ def transform(self, vector): return JavaVectorTransformer.transform(self, vector) -class ChiSqSelectorType: - """ - This class defines the selector types of Chi Square Selector. - """ - KBest, Percentile, FPR = range(3) - - class ChiSqSelector(object): """ Creates a ChiSquared feature selector. The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - `KBest` chooses the `k` top features according to a chi-squared test. - `Percentile` is similar but chooses a fraction of all features instead of a fixed number. - `FPR` chooses all features whose false positive rate meets some threshold. - By default, the selection method is `KBest`, the default number of top features is 50. - User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. + `kbest` chooses the `k` top features according to a chi-squared test. + `percentile` is similar but chooses a fraction of all features instead of a fixed number. + `fpr` chooses all features whose false positive rate meets some threshold. + By default, the selection method is `kbest`, the default number of top features is 50. >>> data = [ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), @@ -299,7 +291,8 @@ class ChiSqSelector(object): SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) DenseVector([5.0]) - >>> model = ChiSqSelector().setPercentile(0.34).fit(sc.parallelize(data)) + >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit( + ... sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) @@ -310,41 +303,52 @@ class ChiSqSelector(object): ... LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]), ... LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0]) ... ] - >>> model = ChiSqSelector().setAlpha(0.1).fit(sc.parallelize(data)) + >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data)) >>> model.transform(DenseVector([1.0,2.0,3.0,4.0])) DenseVector([4.0]) .. versionadded:: 1.4.0 """ - def __init__(self, numTopFeatures=50): + def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05): self.numTopFeatures = numTopFeatures - self.selectorType = ChiSqSelectorType.KBest + self.selectorType = selectorType + self.percentile = percentile + self.alpha = alpha @since('2.1.0') def setNumTopFeatures(self, numTopFeatures): """ - set numTopFeature for feature selection by number of top features + set numTopFeature for feature selection by number of top features. + Only applicable when selectorType = "kbest". """ self.numTopFeatures = int(numTopFeatures) - self.selectorType = ChiSqSelectorType.KBest return self @since('2.1.0') def setPercentile(self, percentile): """ - set percentile [0.0, 1.0] for feature selection by percentile + set percentile [0.0, 1.0] for feature selection by percentile. + Only applicable when selectorType = "percentile". """ self.percentile = float(percentile) - self.selectorType = ChiSqSelectorType.Percentile return self @since('2.1.0') def setAlpha(self, alpha): """ - set alpha [0.0, 1.0] for feature selection by FPR + set alpha [0.0, 1.0] for feature selection by FPR. + Only applicable when selectorType = "fpr". """ self.alpha = float(alpha) - self.selectorType = ChiSqSelectorType.FPR + return self + + @since('2.1.0') + def setSelectorType(self, selectorType): + """ + set the selector type of the ChisqSelector. + Supported options: "kbest" (default), "percentile" and "fpr". + """ + self.selectorType = str(selectorType) return self @since('1.4.0') @@ -357,15 +361,8 @@ def fit(self, data): treated as categorical for each distinct value. Apply feature discretizer before using this function. """ - if self.selectorType == ChiSqSelectorType.KBest: - jmodel = callMLlibFunc("fitChiSqSelectorKBest", self.numTopFeatures, data) - elif self.selectorType == ChiSqSelectorType.Percentile: - jmodel = callMLlibFunc("fitChiSqSelectorPercentile", self.percentile, data) - elif self.selectorType == ChiSqSelectorType.FPR: - jmodel = callMLlibFunc("fitChiSqSelectorFPR", self.alpha, data) - else: - raise ValueError("ChiSqSelector type supports KBest(0), Percentile(1) and" - " FPR(2), the current value is: %s" % self.selectorType) + jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures, + self.percentile, self.alpha, data) return ChiSqSelectorModel(jmodel)