diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f12892f411756..cb3b16d751e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -77,17 +77,17 @@ case class UnaryMinus(child: Expression) extends UnaryExpression val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") defineCodeGen(ctx, ev, interval => if (checkOverflow) { - s"$iu.negate($interval)" + s"$iu.negateExact($interval)" } else { - s"$iu.safeNegate($interval)" + s"$iu.negate($interval)" } ) } protected override def nullSafeEval(input: Any): Any = dataType match { case CalendarIntervalType if checkOverflow => - IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) - case CalendarIntervalType => IntervalUtils.safeNegate(input.asInstanceOf[CalendarInterval]) + IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval]) + case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) case _ => numeric.negate(input) } @@ -232,16 +232,16 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def decimalMethod: String = "$plus" - override def calendarIntervalMethod: String = if (checkOverflow) "add" else "safeAdd" + override def calendarIntervalMethod: String = if (checkOverflow) "addExact" else "add" private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match { case CalendarIntervalType if checkOverflow => - IntervalUtils.add( + IntervalUtils.addExact( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case CalendarIntervalType => - IntervalUtils.safeAdd( + IntervalUtils.add( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case _ => numeric.plus(input1, input2) } @@ -264,16 +264,16 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def decimalMethod: String = "$minus" - override def calendarIntervalMethod: String = if (checkOverflow) "subtract" else "safeSubtract" + override def calendarIntervalMethod: String = if (checkOverflow) "subtractExact" else "subtract" private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match { case CalendarIntervalType if checkOverflow => - IntervalUtils.subtract( + IntervalUtils.subtractExact( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case CalendarIntervalType => - IntervalUtils.safeSubtract( + IntervalUtils.subtract( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case _ => numeric.minus(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index f0793ccb78dd8..7a38c9c76f31f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -135,9 +135,9 @@ case class MultiplyInterval(interval: Expression, num: Expression) override def nullSafeEval(interval: Any, num: Any): Any = { try { if (checkOverflow) { - multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } else { - safeMultiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } } catch { case _: ArithmeticException if !checkOverflow => null @@ -147,7 +147,7 @@ case class MultiplyInterval(interval: Expression, num: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (interval, num) => { val iu = IntervalUtils.getClass.getName.stripSuffix("$") - val operationName = if (checkOverflow) "multiply" else "safeMultiply" + val operationName = if (checkOverflow) "multiplyExact" else "multiply" s""" try { ${ev.value} = $iu.$operationName($interval, $num); @@ -172,9 +172,9 @@ case class DivideInterval(interval: Expression, num: Expression) try { if (num == 0) return null if (checkOverflow) { - divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } else { - safeDivide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } } catch { case _: ArithmeticException if !checkOverflow => null @@ -184,7 +184,7 @@ case class DivideInterval(interval: Expression, num: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (interval, num) => { val iu = IntervalUtils.getClass.getName.stripSuffix("$") - val operationName = if (checkOverflow) "divide" else "safeDivide" + val operationName = if (checkOverflow) "divideExact" else "divide" s""" try { if ($num == 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index b1eec68a4395b..757fded9395cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -138,7 +138,7 @@ object IntervalUtils { assert(input.length == input.trim.length) input match { case yearMonthPattern("-", yearStr, monthStr) => - negate(toInterval(yearStr, monthStr)) + negateExact(toInterval(yearStr, monthStr)) case yearMonthPattern(_, yearStr, monthStr) => toInterval(yearStr, monthStr) case _ => @@ -451,7 +451,7 @@ object IntervalUtils { * @return a new calendar interval instance with all it parameters negated from the origin one. * @throws ArithmeticException if the result overflows any field value */ - def negate(interval: CalendarInterval): CalendarInterval = { + def negateExact(interval: CalendarInterval): CalendarInterval = { val months = Math.negateExact(interval.months) val days = Math.negateExact(interval.days) val microseconds = Math.negateExact(interval.microseconds) @@ -464,7 +464,7 @@ object IntervalUtils { * @param interval the interval to be negated * @return a new calendar interval instance with all it parameters negated from the origin one. */ - def safeNegate(interval: CalendarInterval): CalendarInterval = { + def negate(interval: CalendarInterval): CalendarInterval = { new CalendarInterval(-interval.months, -interval.days, -interval.microseconds) } @@ -474,7 +474,7 @@ object IntervalUtils { * @throws ArithmeticException if the result overflows any field value * */ - def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + def addExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = Math.addExact(left.months, right.months) val days = Math.addExact(left.days, right.days) val microseconds = Math.addExact(left.microseconds, right.microseconds) @@ -484,7 +484,7 @@ object IntervalUtils { /** * Return a new calendar interval instance of the sum of two intervals. */ - def safeAdd(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = left.months + right.months val days = left.days + right.days val microseconds = left.microseconds + right.microseconds @@ -497,7 +497,7 @@ object IntervalUtils { * @throws ArithmeticException if the result overflows any field value * */ - def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + def subtractExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = Math.subtractExact(left.months, right.months) val days = Math.subtractExact(left.days, right.days) val microseconds = Math.subtractExact(left.microseconds, right.microseconds) @@ -507,7 +507,7 @@ object IntervalUtils { /** * Return a new calendar interval instance of the left interval minus the right one. */ - def safeSubtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = left.months - right.months val days = left.days - right.days val microseconds = left.microseconds - right.microseconds @@ -519,14 +519,14 @@ object IntervalUtils { * * @throws ArithmeticException if the result overflows any field value */ - def multiply(interval: CalendarInterval, num: Double): CalendarInterval = { + def multiplyExact(interval: CalendarInterval, num: Double): CalendarInterval = { fromDoubles(num * interval.months, num * interval.days, num * interval.microseconds) } /** * Return a new calendar interval instance of the left interval times a multiplier. */ - def safeMultiply(interval: CalendarInterval, num: Double): CalendarInterval = { + def multiply(interval: CalendarInterval, num: Double): CalendarInterval = { safeFromDoubles(num * interval.months, num * interval.days, num * interval.microseconds) } @@ -535,7 +535,7 @@ object IntervalUtils { * * @throws ArithmeticException if the result overflows any field value or divided by zero */ - def divide(interval: CalendarInterval, num: Double): CalendarInterval = { + def divideExact(interval: CalendarInterval, num: Double): CalendarInterval = { if (num == 0) throw new ArithmeticException("divide by zero") fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) } @@ -545,7 +545,7 @@ object IntervalUtils { * * @throws ArithmeticException if divided by zero */ - def safeDivide(interval: CalendarInterval, num: Double): CalendarInterval = { + def divide(interval: CalendarInterval, num: Double): CalendarInterval = { if (num == 0) throw new ArithmeticException("divide by zero") safeFromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index cc9ebfe409426..9e98e146c7a0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -733,7 +733,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-02 00:00:00")), Literal(Timestamp.valueOf("2018-01-01 00:00:00")), - Literal(negate(stringToInterval("interval 12 hours")))), + Literal(negateExact(stringToInterval("interval 12 hours")))), Seq( Timestamp.valueOf("2018-01-02 00:00:00"), Timestamp.valueOf("2018-01-01 12:00:00"), @@ -742,7 +742,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-02 00:00:00")), Literal(Timestamp.valueOf("2017-12-31 23:59:59")), - Literal(negate(stringToInterval("interval 12 hours")))), + Literal(negateExact(stringToInterval("interval 12 hours")))), Seq( Timestamp.valueOf("2018-01-02 00:00:00"), Timestamp.valueOf("2018-01-01 12:00:00"), @@ -760,7 +760,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-03-01 00:00:00")), Literal(Timestamp.valueOf("2018-01-01 00:00:00")), - Literal(negate(stringToInterval("interval 1 month")))), + Literal(negateExact(stringToInterval("interval 1 month")))), Seq( Timestamp.valueOf("2018-03-01 00:00:00"), Timestamp.valueOf("2018-02-01 00:00:00"), @@ -769,7 +769,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-03-03 00:00:00")), Literal(Timestamp.valueOf("2018-01-01 00:00:00")), - Literal(negate(stringToInterval("interval 1 month 1 day")))), + Literal(negateExact(stringToInterval("interval 1 month 1 day")))), Seq( Timestamp.valueOf("2018-03-03 00:00:00"), Timestamp.valueOf("2018-02-02 00:00:00"), @@ -815,7 +815,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2022-04-01 00:00:00")), Literal(Timestamp.valueOf("2017-01-01 00:00:00")), - Literal(negate(fromYearMonthString("1-5")))), + Literal(negateExact(fromYearMonthString("1-5")))), Seq( Timestamp.valueOf("2022-04-01 00:00:00.000"), Timestamp.valueOf("2020-11-01 00:00:00.000"), @@ -907,7 +907,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper new Sequence( Literal(Date.valueOf("1970-01-01")), Literal(Date.valueOf("1970-02-01")), - Literal(negate(stringToInterval("interval 1 month")))), + Literal(negateExact(stringToInterval("interval 1 month")))), EmptyRow, s"sequence boundaries: 0 to 2678400000000 by -${28 * MICROS_PER_DAY}") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index 2669406cc3487..7aba12b1c9fba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -239,8 +239,8 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } test("negate") { + assert(negateExact(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3)) assert(negate(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3)) - assert(safeNegate(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3)) } test("subtract one interval by another") { @@ -248,7 +248,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { val input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR) val input3 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR) val input4 = new CalendarInterval(75, 150, 200 * MICROS_PER_HOUR) - Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](subtract, safeSubtract) + Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](subtractExact, subtract) .foreach { func => assert(new CalendarInterval(1, -3, -99 * MICROS_PER_HOUR) === func(input1, input2)) assert(new CalendarInterval(-85, -180, -281 * MICROS_PER_HOUR) === func(input3, input4)) @@ -260,14 +260,14 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { val input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR) val input3 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR) val input4 = new CalendarInterval(75, 150, 200 * MICROS_PER_HOUR) - Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](add, safeAdd).foreach { func => + Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](addExact, add).foreach { func => assert(new CalendarInterval(5, 5, 101 * MICROS_PER_HOUR) === func(input1, input2)) assert(new CalendarInterval(65, 120, 119 * MICROS_PER_HOUR) === func(input3, input4)) } } test("multiply by num") { - Seq[(CalendarInterval, Double) => CalendarInterval](multiply, safeMultiply).foreach { func => + Seq[(CalendarInterval, Double) => CalendarInterval](multiplyExact, multiply).foreach { func => var interval = new CalendarInterval(0, 0, 0) assert(interval === func(interval, 0)) interval = new CalendarInterval(123, 456, 789) @@ -281,9 +281,9 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } assert(CalendarInterval.MAX_VALUE === - safeMultiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE)) + multiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE)) try { - multiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE) + multiplyExact(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE) fail("Expected to throw an exception on months overflow") } catch { case e: ArithmeticException => assert(e.getMessage.contains("overflow")) @@ -291,7 +291,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } test("divide by num") { - Seq[(CalendarInterval, Double) => CalendarInterval](divide, safeDivide).foreach { func => + Seq[(CalendarInterval, Double) => CalendarInterval](divideExact, divide).foreach { func => var interval = new CalendarInterval(0, 0, 0) assert(interval === func(interval, 10)) interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND) @@ -457,37 +457,40 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } test("interval overflow check") { - intercept[ArithmeticException](negate(new CalendarInterval(Int.MinValue, 0, 0))) - assert(safeNegate(new CalendarInterval(Int.MinValue, 0, 0)) === + intercept[ArithmeticException](negateExact(new CalendarInterval(Int.MinValue, 0, 0))) + assert(negate(new CalendarInterval(Int.MinValue, 0, 0)) === new CalendarInterval(Int.MinValue, 0, 0)) - intercept[ArithmeticException](negate(CalendarInterval.MIN_VALUE)) - assert(safeNegate(CalendarInterval.MIN_VALUE) === CalendarInterval.MIN_VALUE) - intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1))) - intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0))) - intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0))) - assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)) === + intercept[ArithmeticException](negateExact(CalendarInterval.MIN_VALUE)) + assert(negate(CalendarInterval.MIN_VALUE) === CalendarInterval.MIN_VALUE) + intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE, + new CalendarInterval(0, 0, 1))) + intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE, + new CalendarInterval(0, 1, 0))) + intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE, + new CalendarInterval(1, 0, 0))) + assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)) === new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue)) - assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)) === + assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)) === new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue)) - assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)) === + assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)) === new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue)) - intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, + intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1))) - intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, + intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0))) - intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, + intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0))) - assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1)) === + assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1)) === new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue)) - assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0)) === + assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0)) === new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue)) - assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0)) === + assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0)) === new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue)) - intercept[ArithmeticException](multiply(CalendarInterval.MAX_VALUE, 2)) - assert(safeMultiply(CalendarInterval.MAX_VALUE, 2) === CalendarInterval.MAX_VALUE) - intercept[ArithmeticException](divide(CalendarInterval.MAX_VALUE, 0.5)) - assert(safeDivide(CalendarInterval.MAX_VALUE, 0.5) === CalendarInterval.MAX_VALUE) + intercept[ArithmeticException](multiplyExact(CalendarInterval.MAX_VALUE, 2)) + assert(multiply(CalendarInterval.MAX_VALUE, 2) === CalendarInterval.MAX_VALUE) + intercept[ArithmeticException](divideExact(CalendarInterval.MAX_VALUE, 0.5)) + assert(divide(CalendarInterval.MAX_VALUE, 0.5) === CalendarInterval.MAX_VALUE) } }