diff --git a/docs/ml-features.md b/docs/ml-features.md index a7f710fa52e64..64c6a160239cc 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1103,11 +1103,16 @@ for more details on the API. `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins is set by the `numBuckets` parameter. It is possible -that the number of buckets used will be less than this value, for example, if there are too few -distinct values of the input to create enough distinct quantiles. Note also that NaN values are -handled specially and placed into their own bucket. For example, if 4 buckets are used, then -non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. -The bin ranges are chosen using an approximate algorithm (see the documentation for +that the number of buckets used will be smaller than this value, for example, if there are too few +distinct values of the input to create enough distinct quantiles. + +NaN values: Note also that QuantileDiscretizer +will raise an error when it finds NaN values in the dataset, but the user can also choose to either +keep or remove NaN values within the dataset by setting `handleInvalid`. If the user chooses to keep +NaN values, they will be handled specially and placed into their own bucket, for example, if 4 buckets +are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. + +Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for [approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a detailed description). The precision of the approximation can be controlled with the `relativeError` parameter. When set to zero, exact quantiles are calculated diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index ec0ea05f9e1b1..1143f0f565ebd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -27,6 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql._ +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -46,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String * also includes y. Splits should be of length >= 3 and strictly increasing. * Values at -inf, inf must be explicitly provided to cover all Double values; * otherwise, values outside the splits specified will be treated as errors. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * @group param */ @Since("1.4.0") @@ -73,15 +77,47 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). + * Default: "error" + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + + /** @group setParam */ + @Since("2.1.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - val bucketizer = udf { feature: Double => - Bucketizer.binarySearchForBuckets($(splits), feature) + val (filteredDataset, keepInvalid) = { + if (getHandleInvalid == Bucketizer.SKIP_INVALID) { + // "skip" NaN option is set, will filter out NaN values in the dataset + (dataset.na.drop().toDF(), false) + } else { + (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID) + } + } + + val bucketizer: UserDefinedFunction = udf { (feature: Double) => + Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) } - val newCol = bucketizer(dataset($(inputCol))) - val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol, newField.metadata) + + val newCol = bucketizer(filteredDataset($(inputCol))) + val newField = prepOutputField(filteredDataset.schema) + filteredDataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { @@ -106,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.6.0") object Bucketizer extends DefaultParamsReadable[Bucketizer] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalid: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + /** * We require splits to be of length >= 3 and to be in strictly increasing order. * No NaN split should be accepted. @@ -126,11 +168,26 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** * Binary searching in several buckets to place each data point. + * @param splits array of split points + * @param feature data point + * @param keepInvalid NaN flag. + * Set "true" to make an extra bucket for NaN values; + * Set "false" to report an error for NaN values + * @return bucket for each data point * @throws SparkException if a feature is < splits.head or > splits.last */ - private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { + + private[feature] def binarySearchForBuckets( + splits: Array[Double], + feature: Double, + keepInvalid: Boolean): Double = { if (feature.isNaN) { - splits.length - 1 + if (keepInvalid) { + splits.length - 1 + } else { + throw new SparkException("Bucketizer encountered NaN value. To handle or skip NaNs," + + " try setting Bucketizer.handleInvalid.") + } } else if (feature == splits.last) { splits.length - 2 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 05e034d90f6a3..b9e01dde70d85 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * Number of buckets (quantiles, or categories) into which data points are grouped. Must * be >= 2. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * default: 2 * @group param */ @@ -61,17 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getRelativeError: Double = getOrDefault(relativeError) + + /** + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). + * Default: "error" + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + } /** * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The number of bins can be set using the `numBuckets` parameter. It is - * possible that the number of buckets used will be less than this value, for example, if there - * are too few distinct values of the input to create enough distinct quantiles. Note also that - * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets - * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special - * bucket(4). - * The bin ranges are chosen using an approximate algorithm (see the documentation for + * possible that the number of buckets used will be smaller than this value, for example, if there + * are too few distinct values of the input to create enough distinct quantiles. + * + * NaN handling: Note also that + * QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can + * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`. + * If the user chooses to keep NaN values, they will be handled specially and placed into their own + * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], + * but NaNs will be counted in a special bucket[4]. + * + * Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] * for a detailed description). The precision of the approximation can be controlled with the * `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`, @@ -100,6 +127,10 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.1.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkNumericType(schema, $(inputCol)) @@ -124,7 +155,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + s" buckets as a result.") } - val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) + val bucketizer = new Bucketizer(uid) + .setSplits(distinctSplits.sorted) + .setHandleInvalid($(handleInvalid)) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 87cdceb267387..aac29137d7911 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -99,21 +99,32 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) + bucketizer.setHandleInvalid("keep") bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") } + + bucketizer.setHandleInvalid("skip") + val skipResults: Array[Double] = bucketizer.transform(dataFrame) + .select("result").as[Double].collect() + assert(skipResults.length === 7) + assert(skipResults.forall(_ !== 4.0)) + + bucketizer.setHandleInvalid("error") + withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") { + intercept[SparkException] { + bucketizer.transform(dataFrame).collect() + } + } } test("Bucket continuous features, with NaN splits") { val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) - withClue("Invalid NaN split was not caught as an invalid split!") { + withClue("Invalid NaN split was not caught during Bucketizer initialization") { intercept[IllegalArgumentException] { - val bucketizer: Bucketizer = new Bucketizer() - .setInputCol("feature") - .setOutputCol("result") - .setSplits(splits) + new Bucketizer().setSplits(splits) } } } @@ -138,7 +149,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val data = Array.fill(100)(Random.nextDouble()) val splits: Array[Double] = Double.NegativeInfinity +: Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity - val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x))) + val bsResult = Vectors.dense(data.map(x => + Bucketizer.binarySearchForBuckets(splits, x, false))) val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) assert(bsResult ~== lsResult absTol 1e-5) } @@ -169,7 +181,7 @@ private object BucketizerSuite extends SparkFunSuite { /** Check all values in splits, plus values between all splits. */ def checkBinarySearch(splits: Array[Double]): Unit = { def testFeature(feature: Double, expectedBucket: Double): Unit = { - assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket, + assert(Bucketizer.binarySearchForBuckets(splits, feature, false) === expectedBucket, s"Expected feature value $feature to be in bucket $expectedBucket with splits:" + s" ${splits.mkString(", ")}") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6822594044a56..f219f775b2186 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql._ import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite @@ -76,20 +76,33 @@ class QuantileDiscretizerSuite import spark.implicits._ val numBuckets = 3 - val df = sc.parallelize(Array(1.0, 1.0, 1.0, Double.NaN)) - .map(Tuple1.apply).toDF("input") + val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val expectedKeep = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0) + val expectedSkip = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0) + val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - // Reserve extra one bucket for NaN - val expectedNumBuckets = discretizer.fit(df).getSplits.length - 1 - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBuckets but found $observedNumBuckets") + withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { + val dataFrame: DataFrame = validData.toSeq.toDF("input") + intercept[SparkException] { + discretizer.fit(dataFrame).transform(dataFrame).collect() + } + } + + List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ + case(u, v) => + discretizer.setHandleInvalid(u) + val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") + val result = discretizer.fit(dataFrame).transform(dataFrame) + result.select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, + s"The feature value is not correct after bucketing. Expected $y but found $x") + } + } } test("Test transform method on unseen data") { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 7683360664ebd..94afe82a36472 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1155,11 +1155,6 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter. - It is possible that the number of buckets used will be less than this value, for example, if - there are too few distinct values of the input to create enough distinct quantiles. Note also - that NaN values are handled specially and placed into their own bucket. For example, if 4 - buckets are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in - a special bucket(4). The bin ranges are chosen using an approximate algorithm (see the documentation for :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description). The precision of the approximation can be controlled with the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 73026c749db45..1383208874a19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -150,6 +150,10 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(d1 - 2 * q1 * n) < error_double) assert(math.abs(d2 - 2 * q2 * n) < error_double) } + // test approxQuantile on NaN values + val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input") + val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head) + assert(resNaN.count(_.isNaN) === 0) } test("crosstab") {