Skip to content

Commit

Permalink
Change the implementation to add ignoreTimeZone in Canonicalize
Browse files Browse the repository at this point in the history
  • Loading branch information
xuanyuanking committed Apr 22, 2020
1 parent c23b2ed commit 95102cc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ package org.apache.spark.sql.catalyst.expressions
* The following rules are applied:
* - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
* - Names for [[GetStructField]] are stripped.
* - TimeZoneId for [[Cast]] and [[AnsiCast]] are stripped if `needsTimeZone` is false.
* - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
* by `hashCode`.
* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
Expand All @@ -35,7 +36,7 @@ package org.apache.spark.sql.catalyst.expressions
*/
object Canonicalize {
/**
 * Entry point of canonicalization: returns `e` after `ignoreNamesTypes`, then `ignoreTimeZone`,
 * and finally `expressionReorder` have been applied to it, in that order.
 */
def execute(e: Expression): Expression = {
// NOTE(review): this view shows both sides of a diff hunk. The next line looks like the
// pre-change call (without `ignoreTimeZone`); as written its result is discarded, and the
// line after it is the value actually returned. Confirm against the real file.
expressionReorder(ignoreNamesTypes(e))
expressionReorder(ignoreTimeZone(ignoreNamesTypes(e)))
}

/** Remove names and nullability from types, and names from `GetStructField`. */
Expand All @@ -46,6 +47,15 @@ object Canonicalize {
case _ => e
}

/**
 * Clears the `timeZoneId` of a [[Cast]] or [[AnsiCast]] when its `needsTimeZone` is false and
 * a time zone id is currently set. Any other expression is returned unchanged.
 */
private[expressions] def ignoreTimeZone(e: Expression): Expression = e match {
  case cast: Cast if !cast.needsTimeZone && cast.timeZoneId.isDefined =>
    cast.copy(timeZoneId = None)
  case ansiCast: AnsiCast if !ansiCast.needsTimeZone && ansiCast.timeZoneId.isDefined =>
    ansiCast.copy(timeZoneId = None)
  case other => other
}

/** Collects adjacent commutative operations. */
private def gatherCommutative(
e: Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
override lazy val resolved: Boolean =
childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)

protected[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)
def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)

// [[func]] assumes the input is no longer null because eval already does the null check.
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
Expand Down Expand Up @@ -1708,11 +1708,6 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
""")
case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
extends CastBase {
override lazy val canonicalized: Expression = if (!needsTimeZone && timeZoneId.nonEmpty) {
copy(timeZoneId = None)
} else {
this
}

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
Expand All @@ -1730,11 +1725,6 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
*/
case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
extends CastBase {
override lazy val canonicalized: Expression = if (!needsTimeZone && timeZoneId.nonEmpty) {
copy(timeZoneId = None)
} else {
this
}

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ class CanonicalizeSuite extends SparkFunSuite {
val literal = Literal(1)
val cast = Cast(literal, LongType)
val castWithTimeZoneId = Cast(literal, LongType, Some(TimeZone.getDefault.getID))

assert(castWithTimeZoneId.semanticEquals(cast))
}
}

0 comments on commit 95102cc

Please sign in to comment.