diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 184c5a11298d9..28820681cd3a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -128,6 +128,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * + * @since 2.2.0 + */ + def fill(value: Long): DataFrame = fill(value, df.columns) + + /** + * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ def fill(value: Double): DataFrame = fill(value, df.columns) @@ -139,6 +145,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { */ def fill(value: String): DataFrame = fill(value, df.columns) + /** + * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. + * If a specified column is not a numeric column, it is ignored. + * + * @since 2.2.0 + */ + def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. * If a specified column is not a numeric column, it is ignored. @@ -147,24 +161,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { */ def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + /** + * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified + * numeric columns. If a specified column is not a numeric column, it is ignored. + * + * @since 2.2.0 + */ + def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols) + /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified * numeric columns. If a specified column is not a numeric column, it is ignored. * * @since 1.3.1 */ - def fill(value: Double, cols: Seq[String]): DataFrame = { - val columnEquals = df.sparkSession.sessionState.analyzer.resolver - val projections = df.schema.fields.map { f => - // Only fill if the column is part of the cols list. - if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { - fillCol[Double](f, value) - } else { - df.col(f.name) - } - } - df.select(projections : _*) - } + def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols) + /** * Returns a new `DataFrame` that replaces null values in specified string columns. @@ -180,18 +192,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(value: String, cols: Seq[String]): DataFrame = { - val columnEquals = df.sparkSession.sessionState.analyzer.resolver - val projections = df.schema.fields.map { f => - // Only fill if the column is part of the cols list. - if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { - fillCol[String](f, value) - } else { - df.col(f.name) - } - } - df.select(projections : _*) - } + def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols) /** * Returns a new `DataFrame` that replaces null values. @@ -210,7 +211,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.asScala.toSeq) + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values. @@ -230,7 +231,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq) /** * Replaces values matching keys in `replacement` map with the corresponding values. @@ -368,7 +369,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { df.select(projections : _*) } - private def fill0(values: Seq[(String, Any)]): DataFrame = { + private def fillMap(values: Seq[(String, Any)]): DataFrame = { // Error handling values.foreach { case (colName, replaceValue) => // Check column name exists @@ -435,4 +436,38 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case v => throw new IllegalArgumentException( s"Unsupported value type ${v.getClass.getName} ($v).") } + + /** + * Returns a new `DataFrame` that replaces null or NaN values in specified + * numeric, string columns. If a specified column is not a numeric, string column, + * it is ignored. + */ + private def fillValue[T](value: T, cols: Seq[String]): DataFrame = { + // the fill[T] which T is Long/Double, + // should apply on all the NumericType Column, for example: + // val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b") + // input.na.fill(3.1) + // the result is (3,164.3), not (null, 164.3) + val targetType = value match { + case _: Double | _: Long => NumericType + case _: String => StringType + case _ => throw new IllegalArgumentException( + s"Unsupported value type ${value.getClass.getName} ($value).") + } + + val columnEquals = df.sparkSession.sessionState.analyzer.resolver + val projections = df.schema.fields.map { f => + val typeMatches = (targetType, f.dataType) match { + case (NumericType, dt) => dt.isInstanceOf[NumericType] + case (StringType, dt) => dt == StringType + } + // Only fill if the column is part of the cols list. + if (typeMatches && cols.exists(col => columnEquals(f.name, col))) { + fillCol[T](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 47b55e2547d19..fd829846ac332 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -138,6 +138,24 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), Row("test", null)) + + checkAnswer( + Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L)) + .toDF("a", "b").na.fill(0), + Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) + .toDF("a", "b").na.fill(2.34), + Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) + .toDF("a", "b").na.fill(5), + Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 3.45) :: Nil + ) } test("fill with map") {