diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1f217390518a6..6082c58e2c53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -357,6 +357,7 @@ object TypeCoercion { val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) + case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9a338469d6f93..ec6e6ba0f091b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.CalendarInterval @ExpressionDescription( usage = "_FUNC_(expr) - Returns the negated value of `expr`.", @@ -97,30 +97,20 @@ case class UnaryPositive(child: Expression) case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, StringType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - override def dataType: DataType = child.dataType match { - case dt: NumericType => dt - case dt: StringType => DoubleType - } + override def dataType: DataType = child.dataType private lazy val numeric = TypeUtils.getNumeric(dataType) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))") - case dt: StringType => - defineCodeGen(ctx, ev, c => s"java.lang.Math.abs(Double.valueOf($c.toString()))") } - protected override def nullSafeEval(input: Any): Any = child.dataType match { - case StringType => - numeric.abs(input.asInstanceOf[UTF8String].toString.toDouble) - case _: NumericType => - numeric.abs(input) - } + protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 2239bf815de71..744057b7c5f4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -57,6 +57,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type") + assertError(Abs('stringField), "requires numeric type") assertError(BitwiseNot('stringField), "requires integral type") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index dffaf594683ff..0d86efda7ea86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -205,10 +205,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => checkConsistencyBetweenInterpretedAndCodegen(Abs, tpe) } - - checkEvaluation(Abs(Literal("-1.2")), 1.2) - checkEvaluation(Abs(Literal("-1")), 1.0) - checkEvaluation(Abs(Literal("1.11")), 1.11) } test("pmod") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index a8de23e73892c..a1e8a32ed8f66 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -85,3 +85,6 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu select BIT_LENGTH('abc'); select CHAR_LENGTH('abc'); select OCTET_LENGTH('abc'); + +-- abs +select abs(-3.13), abs('-2.19'); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 85ee10b4d274f..eac3080bec67d 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query 0 @@ -444,3 +444,11 @@ select OCTET_LENGTH('abc') struct -- !query 53 output 3 + + +-- !query 54 +select abs(-3.13), abs('-2.19') +-- !query 54 schema +struct +-- !query 54 output +3.13 2.19