Skip to content

Commit

Permalink
[SPARK-31515][SQL] Canonicalize Cast should consider the value of nee…
Browse files Browse the repository at this point in the history
…dTimeZone

### What changes were proposed in this pull request?
Override the canonicalized fields with respect to the result of `needsTimeZone`.

### Why are the changes needed?
The current approach breaks the semantic equality of two cast expressions that are unrelated to datetime types. If casting from the `from` type to the `to` type does not need `timeZone` information, then the timeZoneId should not influence the canonicalized result.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
New UT added.

Closes #28288 from xuanyuanking/SPARK-31515.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
(cherry picked from commit ca90e19)
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
  • Loading branch information
xuanyuanking authored and maropu committed Apr 23, 2020
1 parent ed3e4bd commit 2ebef75
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ package org.apache.spark.sql.catalyst.expressions
* The following rules are applied:
* - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
* - Names for [[GetStructField]] are stripped.
* - TimeZoneId for [[Cast]] and [[AnsiCast]] is stripped if `needsTimeZone` is false.
* - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
* by `hashCode`.
* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
Expand All @@ -35,7 +36,7 @@ package org.apache.spark.sql.catalyst.expressions
*/
object Canonicalize {
def execute(e: Expression): Expression = {
expressionReorder(ignoreNamesTypes(e))
expressionReorder(ignoreTimeZone(ignoreNamesTypes(e)))
}

/** Remove names and nullability from types, and names from `GetStructField`. */
Expand All @@ -46,6 +47,13 @@ object Canonicalize {
case _ => e
}

/**
 * Strips the time zone id from a cast that carries one but does not need it.
 *
 * Any expression that is not a [[CastBase]], or a cast whose time zone id is empty or whose
 * `needsTimeZone` is true, is returned unchanged.
 */
private[expressions] def ignoreTimeZone(e: Expression): Expression = e match {
  case cast: CastBase =>
    // Check the id first so `needsTimeZone` is only evaluated for casts that carry one.
    val hasRedundantZone = cast.timeZoneId.isDefined && !cast.needsTimeZone
    // A null id is passed on purpose: it asks `withTimeZone` for a copy without a time zone.
    if (hasRedundantZone) cast.withTimeZone(null) else cast
  case other => other
}

/** Collects adjacent commutative operations. */
private def gatherCommutative(
e: Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
override lazy val resolved: Boolean =
childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)

private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)
def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)

// [[func]] assumes the input is no longer null because eval already does the null check.
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
Expand Down Expand Up @@ -1708,6 +1708,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
""")
case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
extends CastBase {

/**
 * Returns a copy of this cast carrying the given time zone id.
 * The id is wrapped with `Option(...)`, so a null id produces a copy with no time zone.
 */
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
  val zone = Option(timeZoneId)
  copy(timeZoneId = zone)
}

Expand All @@ -1724,6 +1725,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
*/
case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
extends CastBase {

/**
 * Returns a copy of this ANSI cast carrying the given time zone id.
 * The id is wrapped with `Option(...)`, so a null id produces a copy with no time zone.
 */
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
  val zone = Option(timeZoneId)
  copy(timeZoneId = zone)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

package org.apache.spark.sql.catalyst.expressions

import java.util.TimeZone

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

class CanonicalizeSuite extends SparkFunSuite {

Expand Down Expand Up @@ -86,4 +88,11 @@ class CanonicalizeSuite extends SparkFunSuite {
val subExpr = Subtract(range.output.head, Literal(1))
assert(addExpr.canonicalized.hashCode() != subExpr.canonicalized.hashCode())
}

test("SPARK-31515: Canonicalize Cast should consider the value of needTimeZone") {
  // Two casts of the same child to LongType, differing only in whether a time zone id is set,
  // are expected to be semantically equal.
  val child = Literal(1)
  val withoutZone = Cast(child, LongType)
  val defaultZoneId = TimeZone.getDefault.getID
  val withZone = Cast(child, LongType, Some(defaultZoneId))
  assert(withZone.semanticEquals(withoutZone))
}
}

0 comments on commit 2ebef75

Please sign in to comment.