diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d38839daf182d..573fe047eb589 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -134,15 +134,13 @@ object NullPropagation extends Rule[LogicalPlan] { case Literal(null, _) => Literal(null, e.dataType) case _ => e } - case e: And => e // leave it for BooleanSimplification - case e: Or => e // leave it for BooleanSimplification - // Put exceptional cases above + // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) case _ => e } - case e: BinaryPredicate => e.children match { + case e: BinaryComparison => e.children match { case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) case _ => e diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index cc6df1088e737..91605d0a260e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -250,17 +250,17 @@ class ExpressionEvaluationSuite extends FunSuite { intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} - checkEvaluation(("abcdef" cast StringType).nullable, false) - checkEvaluation(("abcdef" cast BinaryType).nullable,false) - checkEvaluation(("abcdef" cast BooleanType).nullable, false) - checkEvaluation(("abcdef" cast TimestampType).nullable, true) - checkEvaluation(("abcdef" cast LongType).nullable, true) - checkEvaluation(("abcdef" cast IntegerType).nullable, true) - checkEvaluation(("abcdef" cast ShortType).nullable, true) - checkEvaluation(("abcdef" cast ByteType).nullable, true) - checkEvaluation(("abcdef" cast DecimalType).nullable, true) - checkEvaluation(("abcdef" cast DoubleType).nullable, true) - checkEvaluation(("abcdef" cast FloatType).nullable, true) + assert(("abcdef" cast StringType).nullable === false) + assert(("abcdef" cast BinaryType).nullable === false) + assert(("abcdef" cast BooleanType).nullable === false) + assert(("abcdef" cast TimestampType).nullable === true) + assert(("abcdef" cast LongType).nullable === true) + assert(("abcdef" cast IntegerType).nullable === true) + assert(("abcdef" cast ShortType).nullable === true) + assert(("abcdef" cast ByteType).nullable === true) + assert(("abcdef" cast DecimalType).nullable === true) + assert(("abcdef" cast DoubleType).nullable === true) + assert(("abcdef" cast FloatType).nullable === true) checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) }