diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 768897dc0713c..e1dd010d37a95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -571,6 +571,7 @@ object TypeCoercion { NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => NaNvl(Cast(l, DoubleType), r) + case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3e0c357b6de42..011d09ff60641 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -656,14 +656,20 @@ class TypeCoercionSuite extends PlanTest { test("nanvl casts") { ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), - NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType))) + NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), + NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)), - NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType))) + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 93d565d9fe904..052d85ad33bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -408,8 +408,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val quotedColName = "`" + col.name + "`" val colValue = col.dataType match { case DoubleType | FloatType => - // nanvl only supports these types - nanvl(df.col(quotedColName), lit(null).cast(col.dataType)) + nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types case _ => df.col(quotedColName) } coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)