diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 4d218b936b3a2..a8031086d82f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -27,6 +27,7 @@ package org.apache.spark.sql.catalyst.expressions * The following rules are applied: * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped. * - Names for [[GetStructField]] are stripped. + * - TimeZoneId for [[Cast]] and [[AnsiCast]] are stripped if `needsTimeZone` is false. * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered * by `hashCode`. * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. @@ -35,7 +36,7 @@ package org.apache.spark.sql.catalyst.expressions */ object Canonicalize { def execute(e: Expression): Expression = { - expressionReorder(ignoreNamesTypes(e)) + expressionReorder(ignoreTimeZone(ignoreNamesTypes(e))) } /** Remove names and nullability from types, and names from `GetStructField`. */ @@ -46,6 +47,13 @@ object Canonicalize { case _ => e } + /** Remove TimeZoneId for Cast if needsTimeZone return false. */ + private[expressions] def ignoreTimeZone(e: Expression): Expression = e match { + case c: CastBase if c.timeZoneId.nonEmpty && !c.needsTimeZone => + c.withTimeZone(null) + case _ => e + } + /** Collects adjacent commutative operations. */ private def gatherCommutative( e: Expression, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8d82956cc6f74..fa615d71a61a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -279,7 +279,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined) - private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) + def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) @@ -1708,6 +1708,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit """) case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None) extends CastBase { + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1724,6 +1725,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String */ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None) extends CastBase { + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index dd719437d618d..a043b4cbed1f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.TimeZone + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType} class CanonicalizeSuite extends SparkFunSuite { @@ -86,4 +88,11 @@ class CanonicalizeSuite extends SparkFunSuite { val subExpr = Subtract(range.output.head, Literal(1)) assert(addExpr.canonicalized.hashCode() != subExpr.canonicalized.hashCode()) } + + test("SPARK-31515: Canonicalize Cast should consider the value of needTimeZone") { + val literal = Literal(1) + val cast = Cast(literal, LongType) + val castWithTimeZoneId = Cast(literal, LongType, Some(TimeZone.getDefault.getID)) + assert(castWithTimeZoneId.semanticEquals(cast)) + } }