Skip to content

Commit

Permalink
Move the type promotion to the rule PromoteStrings.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Jun 16, 2017
1 parent 5f5fee6 commit afdb40d
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ object TypeCoercion {
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))

case Abs(e @ StringType()) => Abs(Cast(e, DoubleType))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.CalendarInterval

@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the negated value of `expr`.",
Expand Down Expand Up @@ -97,30 +97,20 @@ case class UnaryPositive(child: Expression)
case class Abs(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, StringType))
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def dataType: DataType = child.dataType match {
case dt: NumericType => dt
case dt: StringType => DoubleType
}
override def dataType: DataType = child.dataType

private lazy val numeric = TypeUtils.getNumeric(dataType)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.dataType match {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, c => s"$c.abs()")
case dt: NumericType =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))")
case dt: StringType =>
defineCodeGen(ctx, ev, c => s"java.lang.Math.abs(Double.valueOf($c.toString()))")
}

protected override def nullSafeEval(input: Any): Any = child.dataType match {
case StringType =>
numeric.abs(input.asInstanceOf[UTF8String].toString.toDouble)
case _: NumericType =>
numeric.abs(input)
}
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
}

abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {

test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type")
assertError(Abs('stringField), "requires numeric type")
assertError(BitwiseNot('stringField), "requires integral type")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Abs, tpe)
}

checkEvaluation(Abs(Literal("-1.2")), 1.2)
checkEvaluation(Abs(Literal("-1")), 1.0)
checkEvaluation(Abs(Literal("1.11")), 1.11)
}

test("pmod") {
Expand Down
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/operators.sql
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,6 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu
select BIT_LENGTH('abc');
select CHAR_LENGTH('abc');
select OCTET_LENGTH('abc');

-- abs
select abs(-3.13), abs('-2.19');
10 changes: 9 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/operators.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 54
-- Number of queries: 55


-- !query 0
Expand Down Expand Up @@ -444,3 +444,11 @@ select OCTET_LENGTH('abc')
struct<octetlength(abc):int>
-- !query 53 output
3


-- !query 54
select abs(-3.13), abs('-2.19')
-- !query 54 schema
struct<abs(-3.13):decimal(3,2),abs(CAST(-2.19 AS DOUBLE)):double>
-- !query 54 output
3.13 2.19

0 comments on commit afdb40d

Please sign in to comment.