Skip to content

Commit

Permalink
name with exact
Browse files Browse the repository at this point in the history
  • Loading branch information
yaooqinn committed Dec 27, 2019
1 parent 7671d83 commit b679381
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev,
interval => if (checkOverflow) {
s"$iu.negate($interval)"
s"$iu.negateExact($interval)"
} else {
s"$iu.safeNegate($interval)"
s"$iu.negate($interval)"
}
)
}

protected override def nullSafeEval(input: Any): Any = dataType match {
case CalendarIntervalType if checkOverflow =>
IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
case CalendarIntervalType => IntervalUtils.safeNegate(input.asInstanceOf[CalendarInterval])
IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval])
case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
case _ => numeric.negate(input)
}

Expand Down Expand Up @@ -232,16 +232,16 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

override def decimalMethod: String = "$plus"

override def calendarIntervalMethod: String = if (checkOverflow) "add" else "safeAdd"
override def calendarIntervalMethod: String = if (checkOverflow) "addExact" else "add"

private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
case CalendarIntervalType if checkOverflow =>
IntervalUtils.add(
IntervalUtils.addExact(
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
case CalendarIntervalType =>
IntervalUtils.safeAdd(
IntervalUtils.add(
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
case _ => numeric.plus(input1, input2)
}
Expand All @@ -264,16 +264,16 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti

override def decimalMethod: String = "$minus"

override def calendarIntervalMethod: String = if (checkOverflow) "subtract" else "safeSubtract"
override def calendarIntervalMethod: String = if (checkOverflow) "subtractExact" else "subtract"

private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
case CalendarIntervalType if checkOverflow =>
IntervalUtils.subtract(
IntervalUtils.subtractExact(
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
case CalendarIntervalType =>
IntervalUtils.safeSubtract(
IntervalUtils.subtract(
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
case _ => numeric.minus(input1, input2)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ case class MultiplyInterval(interval: Expression, num: Expression)
override def nullSafeEval(interval: Any, num: Any): Any = {
try {
if (checkOverflow) {
multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
} else {
safeMultiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
}
} catch {
case _: ArithmeticException if !checkOverflow => null
Expand All @@ -147,7 +147,7 @@ case class MultiplyInterval(interval: Expression, num: Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (interval, num) => {
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
val operationName = if (checkOverflow) "multiply" else "safeMultiply"
val operationName = if (checkOverflow) "multiplyExact" else "multiply"
s"""
try {
${ev.value} = $iu.$operationName($interval, $num);
Expand All @@ -172,9 +172,9 @@ case class DivideInterval(interval: Expression, num: Expression)
try {
if (num == 0) return null
if (checkOverflow) {
divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
} else {
safeDivide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
}
} catch {
case _: ArithmeticException if !checkOverflow => null
Expand All @@ -184,7 +184,7 @@ case class DivideInterval(interval: Expression, num: Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (interval, num) => {
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
val operationName = if (checkOverflow) "divide" else "safeDivide"
val operationName = if (checkOverflow) "divideExact" else "divide"
s"""
try {
if ($num == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ object IntervalUtils {
assert(input.length == input.trim.length)
input match {
case yearMonthPattern("-", yearStr, monthStr) =>
negate(toInterval(yearStr, monthStr))
negateExact(toInterval(yearStr, monthStr))
case yearMonthPattern(_, yearStr, monthStr) =>
toInterval(yearStr, monthStr)
case _ =>
Expand Down Expand Up @@ -451,7 +451,7 @@ object IntervalUtils {
* @return a new calendar interval instance with all it parameters negated from the origin one.
* @throws ArithmeticException if the result overflows any field value
*/
def negate(interval: CalendarInterval): CalendarInterval = {
def negateExact(interval: CalendarInterval): CalendarInterval = {
val months = Math.negateExact(interval.months)
val days = Math.negateExact(interval.days)
val microseconds = Math.negateExact(interval.microseconds)
Expand All @@ -464,7 +464,7 @@ object IntervalUtils {
* @param interval the interval to be negated
* @return a new calendar interval instance with all it parameters negated from the origin one.
*/
def safeNegate(interval: CalendarInterval): CalendarInterval = {
def negate(interval: CalendarInterval): CalendarInterval = {
new CalendarInterval(-interval.months, -interval.days, -interval.microseconds)
}

Expand All @@ -474,7 +474,7 @@ object IntervalUtils {
* @throws ArithmeticException if the result overflows any field value
*
*/
def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
def addExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
val months = Math.addExact(left.months, right.months)
val days = Math.addExact(left.days, right.days)
val microseconds = Math.addExact(left.microseconds, right.microseconds)
Expand All @@ -484,7 +484,7 @@ object IntervalUtils {
/**
* Return a new calendar interval instance of the sum of two intervals.
*/
def safeAdd(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
val months = left.months + right.months
val days = left.days + right.days
val microseconds = left.microseconds + right.microseconds
Expand All @@ -497,7 +497,7 @@ object IntervalUtils {
* @throws ArithmeticException if the result overflows any field value
*
*/
def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
def subtractExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
val months = Math.subtractExact(left.months, right.months)
val days = Math.subtractExact(left.days, right.days)
val microseconds = Math.subtractExact(left.microseconds, right.microseconds)
Expand All @@ -507,7 +507,7 @@ object IntervalUtils {
/**
* Return a new calendar interval instance of the left interval minus the right one.
*/
def safeSubtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
val months = left.months - right.months
val days = left.days - right.days
val microseconds = left.microseconds - right.microseconds
Expand All @@ -519,14 +519,14 @@ object IntervalUtils {
*
* @throws ArithmeticException if the result overflows any field value
*/
def multiply(interval: CalendarInterval, num: Double): CalendarInterval = {
def multiplyExact(interval: CalendarInterval, num: Double): CalendarInterval = {
fromDoubles(num * interval.months, num * interval.days, num * interval.microseconds)
}

/**
* Return a new calendar interval instance of the left interval times a multiplier.
*/
def safeMultiply(interval: CalendarInterval, num: Double): CalendarInterval = {
def multiply(interval: CalendarInterval, num: Double): CalendarInterval = {
safeFromDoubles(num * interval.months, num * interval.days, num * interval.microseconds)
}

Expand All @@ -535,7 +535,7 @@ object IntervalUtils {
*
* @throws ArithmeticException if the result overflows any field value or divided by zero
*/
def divide(interval: CalendarInterval, num: Double): CalendarInterval = {
def divideExact(interval: CalendarInterval, num: Double): CalendarInterval = {
if (num == 0) throw new ArithmeticException("divide by zero")
fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
}
Expand All @@ -545,7 +545,7 @@ object IntervalUtils {
*
* @throws ArithmeticException if divided by zero
*/
def safeDivide(interval: CalendarInterval, num: Double): CalendarInterval = {
def divide(interval: CalendarInterval, num: Double): CalendarInterval = {
if (num == 0) throw new ArithmeticException("divide by zero")
safeFromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(negate(stringToInterval("interval 12 hours")))),
Literal(negateExact(stringToInterval("interval 12 hours")))),
Seq(
Timestamp.valueOf("2018-01-02 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
Expand All @@ -742,7 +742,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
Literal(Timestamp.valueOf("2017-12-31 23:59:59")),
Literal(negate(stringToInterval("interval 12 hours")))),
Literal(negateExact(stringToInterval("interval 12 hours")))),
Seq(
Timestamp.valueOf("2018-01-02 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
Expand All @@ -760,7 +760,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(negate(stringToInterval("interval 1 month")))),
Literal(negateExact(stringToInterval("interval 1 month")))),
Seq(
Timestamp.valueOf("2018-03-01 00:00:00"),
Timestamp.valueOf("2018-02-01 00:00:00"),
Expand All @@ -769,7 +769,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-03-03 00:00:00")),
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(negate(stringToInterval("interval 1 month 1 day")))),
Literal(negateExact(stringToInterval("interval 1 month 1 day")))),
Seq(
Timestamp.valueOf("2018-03-03 00:00:00"),
Timestamp.valueOf("2018-02-02 00:00:00"),
Expand Down Expand Up @@ -815,7 +815,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2022-04-01 00:00:00")),
Literal(Timestamp.valueOf("2017-01-01 00:00:00")),
Literal(negate(fromYearMonthString("1-5")))),
Literal(negateExact(fromYearMonthString("1-5")))),
Seq(
Timestamp.valueOf("2022-04-01 00:00:00.000"),
Timestamp.valueOf("2020-11-01 00:00:00.000"),
Expand Down Expand Up @@ -907,7 +907,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
new Sequence(
Literal(Date.valueOf("1970-01-01")),
Literal(Date.valueOf("1970-02-01")),
Literal(negate(stringToInterval("interval 1 month")))),
Literal(negateExact(stringToInterval("interval 1 month")))),
EmptyRow,
s"sequence boundaries: 0 to 2678400000000 by -${28 * MICROS_PER_DAY}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,16 +239,16 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
}

test("negate") {
assert(negateExact(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3))
assert(negate(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3))
assert(safeNegate(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3))
}

test("subtract one interval by another") {
val input1 = new CalendarInterval(3, 1, 1 * MICROS_PER_HOUR)
val input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR)
val input3 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR)
val input4 = new CalendarInterval(75, 150, 200 * MICROS_PER_HOUR)
Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](subtract, safeSubtract)
Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](subtractExact, subtract)
.foreach { func =>
assert(new CalendarInterval(1, -3, -99 * MICROS_PER_HOUR) === func(input1, input2))
assert(new CalendarInterval(-85, -180, -281 * MICROS_PER_HOUR) === func(input3, input4))
Expand All @@ -260,14 +260,14 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
val input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR)
val input3 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR)
val input4 = new CalendarInterval(75, 150, 200 * MICROS_PER_HOUR)
Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](add, safeAdd).foreach { func =>
Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](addExact, add).foreach { func =>
assert(new CalendarInterval(5, 5, 101 * MICROS_PER_HOUR) === func(input1, input2))
assert(new CalendarInterval(65, 120, 119 * MICROS_PER_HOUR) === func(input3, input4))
}
}

test("multiply by num") {
Seq[(CalendarInterval, Double) => CalendarInterval](multiply, safeMultiply).foreach { func =>
Seq[(CalendarInterval, Double) => CalendarInterval](multiplyExact, multiply).foreach { func =>
var interval = new CalendarInterval(0, 0, 0)
assert(interval === func(interval, 0))
interval = new CalendarInterval(123, 456, 789)
Expand All @@ -281,17 +281,17 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
}

assert(CalendarInterval.MAX_VALUE ===
safeMultiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE))
multiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE))
try {
multiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE)
multiplyExact(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE)
fail("Expected to throw an exception on months overflow")
} catch {
case e: ArithmeticException => assert(e.getMessage.contains("overflow"))
}
}

test("divide by num") {
Seq[(CalendarInterval, Double) => CalendarInterval](divide, safeDivide).foreach { func =>
Seq[(CalendarInterval, Double) => CalendarInterval](divideExact, divide).foreach { func =>
var interval = new CalendarInterval(0, 0, 0)
assert(interval === func(interval, 10))
interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND)
Expand Down Expand Up @@ -457,37 +457,40 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
}

test("interval overflow check") {
intercept[ArithmeticException](negate(new CalendarInterval(Int.MinValue, 0, 0)))
assert(safeNegate(new CalendarInterval(Int.MinValue, 0, 0)) ===
intercept[ArithmeticException](negateExact(new CalendarInterval(Int.MinValue, 0, 0)))
assert(negate(new CalendarInterval(Int.MinValue, 0, 0)) ===
new CalendarInterval(Int.MinValue, 0, 0))
intercept[ArithmeticException](negate(CalendarInterval.MIN_VALUE))
assert(safeNegate(CalendarInterval.MIN_VALUE) === CalendarInterval.MIN_VALUE)
intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)))
intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)))
intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)))
assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)) ===
intercept[ArithmeticException](negateExact(CalendarInterval.MIN_VALUE))
assert(negate(CalendarInterval.MIN_VALUE) === CalendarInterval.MIN_VALUE)
intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, 0, 1)))
intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, 1, 0)))
intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(1, 0, 0)))
assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)) ===
new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue))
assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)) ===
assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)) ===
new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue))
assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)) ===
assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)) ===
new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue))

intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE,
intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, 0, -1)))
intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE,
intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, -1, 0)))
intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE,
intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(-1, 0, 0)))
assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1)) ===
assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1)) ===
new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue))
assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0)) ===
assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0)) ===
new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue))
assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0)) ===
assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0)) ===
new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue))

intercept[ArithmeticException](multiply(CalendarInterval.MAX_VALUE, 2))
assert(safeMultiply(CalendarInterval.MAX_VALUE, 2) === CalendarInterval.MAX_VALUE)
intercept[ArithmeticException](divide(CalendarInterval.MAX_VALUE, 0.5))
assert(safeDivide(CalendarInterval.MAX_VALUE, 0.5) === CalendarInterval.MAX_VALUE)
intercept[ArithmeticException](multiplyExact(CalendarInterval.MAX_VALUE, 2))
assert(multiply(CalendarInterval.MAX_VALUE, 2) === CalendarInterval.MAX_VALUE)
intercept[ArithmeticException](divideExact(CalendarInterval.MAX_VALUE, 0.5))
assert(divide(CalendarInterval.MAX_VALUE, 0.5) === CalendarInterval.MAX_VALUE)
}
}

0 comments on commit b679381

Please sign in to comment.