diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java index f3becbf922..e09b2a0a19 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java @@ -73,6 +73,14 @@ public FunctionExpression ceiling(Expression... expressions) { return function(BuiltinFunctionName.CEILING, expressions); } + public FunctionExpression conv(Expression... expressions) { + return function(BuiltinFunctionName.CONV, expressions); + } + + public FunctionExpression crc32(Expression... expressions) { + return function(BuiltinFunctionName.CRC32, expressions); + } + public FunctionExpression exp(Expression... expressions) { return function(BuiltinFunctionName.EXP, expressions); } @@ -97,6 +105,34 @@ public FunctionExpression log2(Expression... expressions) { return function(BuiltinFunctionName.LOG2, expressions); } + public FunctionExpression mod(Expression... expressions) { + return function(BuiltinFunctionName.MOD, expressions); + } + + public FunctionExpression pow(Expression... expressions) { + return function(BuiltinFunctionName.POW, expressions); + } + + public FunctionExpression power(Expression... expressions) { + return function(BuiltinFunctionName.POWER, expressions); + } + + public FunctionExpression round(Expression... expressions) { + return function(BuiltinFunctionName.ROUND, expressions); + } + + public FunctionExpression sign(Expression... expressions) { + return function(BuiltinFunctionName.SIGN, expressions); + } + + public FunctionExpression sqrt(Expression... expressions) { + return function(BuiltinFunctionName.SQRT, expressions); + } + + public FunctionExpression truncate(Expression... expressions) { + return function(BuiltinFunctionName.TRUNCATE, expressions); + } + public FunctionExpression add(Expression... expressions) { return function(BuiltinFunctionName.ADD, expressions); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java index 9a754413fa..282cc0bb43 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java @@ -18,12 +18,21 @@ public enum BuiltinFunctionName { ABS(FunctionName.of("abs")), CEIL(FunctionName.of("ceil")), CEILING(FunctionName.of("ceiling")), + CONV(FunctionName.of("conv")), + CRC32(FunctionName.of("crc32")), EXP(FunctionName.of("exp")), FLOOR(FunctionName.of("floor")), LN(FunctionName.of("ln")), LOG(FunctionName.of("log")), LOG10(FunctionName.of("log10")), LOG2(FunctionName.of("log2")), + MOD(FunctionName.of("mod")), + POW(FunctionName.of("pow")), + POWER(FunctionName.of("power")), + ROUND(FunctionName.of("round")), + SIGN(FunctionName.of("sign")), + SQRT(FunctionName.of("sqrt")), + TRUNCATE(FunctionName.of("truncate")), /** * Text Functions. diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/OperatorUtils.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/OperatorUtils.java index 52216eb496..568cb6a8e3 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/OperatorUtils.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/OperatorUtils.java @@ -34,6 +34,60 @@ @UtilityClass public class OperatorUtils { + /** + * Construct {@link FunctionBuilder} which call function with three arguments produced by + * observers.In general, if any operand evaluates to a MISSING value, the enclosing operator + * will return MISSING; if none of operands evaluates to a MISSING value but there is an + * operand evaluates to a NULL value, the enclosing operator will return NULL. + * + * @param functionName function name + * @param function {@link BiFunction} + * @param observer1 extract the value of type T from the first argument + * @param observer2 extract the value of type U from the first argument + * @param observer3 extract the value of type V from the first argument + * @param returnType return type + * @param the type of the first argument to the function + * @param the type of the second argument to the function + * @param the type of the third argument to the function + * @param the type of the result of the function + * @return {@link FunctionBuilder} + */ + public static FunctionBuilder tripleArgFunc( + FunctionName functionName, + TriFunction function, + Function observer1, + Function observer2, + Function observer3, + ExprCoreType returnType) { + return arguments -> new FunctionExpression(functionName, arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + ExprValue arg1 = arguments.get(0).valueOf(valueEnv); + ExprValue arg2 = arguments.get(1).valueOf(valueEnv); + ExprValue arg3 = arguments.get(2).valueOf(valueEnv); + if (arg1.isMissing() || arg2.isMissing() || arg3.isMissing()) { + return ExprValueUtils.missingValue(); + } else if (arg1.isNull() || arg2.isNull() || arg3.isNull()) { + return ExprValueUtils.nullValue(); + } else { + return ExprValueUtils.fromObjectValue( + function.apply(observer1.apply(arg1), observer2.apply(arg2), observer3.apply(arg3))); + } + } + + @Override + public ExprType type() { + return returnType; + } + + @Override + public String toString() { + return String.format("%s(%s, %s, %s)", functionName, arguments.get(0).toString(), arguments + .get(1).toString(), arguments.get(2).toString()); + } + }; + } + /** * Construct {@link FunctionBuilder} which call function with arguments produced by observer. * @@ -222,4 +276,8 @@ public String toString() { */ public static final BiPredicate COMPARE_WITH_NULL_OR_MISSING = (left, right) -> left.isMissing() || right.isMissing() || left.isNull() || right.isNull(); + + public interface TriFunction { + R apply(T t, U u, V v); + } } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/ArithmeticFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/ArithmeticFunction.java index 300b89065d..cf296ef056 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/ArithmeticFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/ArithmeticFunction.java @@ -94,10 +94,10 @@ private static FunctionResolver divide() { return new FunctionResolver( BuiltinFunctionName.DIVIDE.getName(), scalarFunction(BuiltinFunctionName.DIVIDE.getName(), - (v1, v2) -> v1 / v2, - (v1, v2) -> v1 / v2, - (v1, v2) -> v1 / v2, - (v1, v2) -> v1 / v2) + (v1, v2) -> v2 == 0 ? null : v1 / v2, + (v1, v2) -> v2 == 0 ? null : v1 / v2, + (v1, v2) -> v2 == 0 ? null : v1 / v2, + (v1, v2) -> v2 == 0 ? null : v1 / v2) ); } @@ -106,10 +106,10 @@ private static FunctionResolver modules() { return new FunctionResolver( BuiltinFunctionName.MODULES.getName(), scalarFunction(BuiltinFunctionName.MODULES.getName(), - (v1, v2) -> v1 % v2, - (v1, v2) -> v1 % v2, - (v1, v2) -> v1 % v2, - (v1, v2) -> v1 % v2) + (v1, v2) -> v2 == 0 ? null : v1 % v2, + (v1, v2) -> v2 == 0 ? null : v1 % v2, + (v1, v2) -> v2 == 0 ? null : v1 % v2, + (v1, v2) -> v2 == 0 ? null : v1 % v2) ); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/MathematicalFunction.java index 962ddf15eb..0bed318e72 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -16,6 +16,7 @@ package com.amazon.opendistroforelasticsearch.sql.expression.operator.arthmetic; import static com.amazon.opendistroforelasticsearch.sql.expression.operator.OperatorUtils.doubleArgFunc; +import static com.amazon.opendistroforelasticsearch.sql.expression.operator.OperatorUtils.tripleArgFunc; import static com.amazon.opendistroforelasticsearch.sql.expression.operator.OperatorUtils.unaryOperator; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; @@ -27,10 +28,13 @@ import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionResolver; import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionSignature; import com.google.common.collect.ImmutableMap; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.Arrays; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.zip.CRC32; import lombok.experimental.UtilityClass; @UtilityClass @@ -44,12 +48,21 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(abs()); repository.register(ceil()); repository.register(ceiling()); + repository.register(conv()); + repository.register(crc32()); repository.register(exp()); repository.register(floor()); repository.register(ln()); repository.register(log()); repository.register(log10()); repository.register(log2()); + repository.register(mod()); + repository.register(pow()); + repository.register(power()); + repository.register(round()); + repository.register(sign()); + repository.register(sqrt()); + repository.register(truncate()); } /** @@ -97,6 +110,60 @@ private static FunctionResolver ceiling() { .build()); } + /** + * Definition of conv(x, a, b) function. + * Convert number x from base a to base b + * The supported signature of floor function is + * (STRING, INTEGER, INTEGER) -> STRING + */ + private static FunctionResolver conv() { + FunctionName functionName = BuiltinFunctionName.CONV.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put( + new FunctionSignature(functionName, + Arrays.asList(ExprCoreType.STRING, ExprCoreType.INTEGER, ExprCoreType.INTEGER)), + tripleArgFunc(functionName, + (num, fromBase, toBase) -> Integer.toString( + Integer.parseInt(num, fromBase), toBase), + ExprValueUtils::getStringValue, ExprValueUtils::getIntegerValue, + ExprValueUtils::getIntegerValue, ExprCoreType.STRING)) + .put( + new FunctionSignature(functionName, + Arrays.asList( + ExprCoreType.INTEGER, ExprCoreType.INTEGER, ExprCoreType.INTEGER)), + tripleArgFunc(functionName, + (num, fromBase, toBase) -> Integer.toString( + Integer.parseInt(num.toString(), fromBase), toBase), + ExprValueUtils::getIntegerValue, ExprValueUtils::getIntegerValue, + ExprValueUtils::getIntegerValue, ExprCoreType.STRING)) + .build()); + } + + /** + * Definition of crc32(x) function. + * Calculate a cyclic redundancy check value and returns a 32-bit unsigned value + * The supported signature of crc32 function is + * STRING -> LONG + */ + private static FunctionResolver crc32() { + FunctionName functionName = BuiltinFunctionName.CRC32.getName(); + return new FunctionResolver(functionName, + new ImmutableMap.Builder() + .put( + new FunctionSignature(functionName, Arrays.asList(ExprCoreType.STRING)), + unaryOperator( + functionName, + v -> { + CRC32 crc = new CRC32(); + crc.update(v.getBytes()); + return crc.getValue(); + }, + ExprValueUtils::getStringValue, ExprCoreType.LONG)) + .build()); + } + /** * Definition of exp(x) function. Calculate exponent function e to the x The supported signature * of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE @@ -182,6 +249,186 @@ private static FunctionResolver log2() { singleArgumentFunction(BuiltinFunctionName.LOG2.getName(), v -> Math.log(v) / Math.log(2))); } + /** + * Definition of mod(x, y) function. + * Calculate the remainder of x divided by y + * The supported signature of mod function is + * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static FunctionResolver mod() { + return new FunctionResolver( + BuiltinFunctionName.MOD.getName(), + doubleArgumentsFunction(BuiltinFunctionName.MOD.getName(), + (v1, v2) -> v2 == 0 ? null : v1 % v2, + (v1, v2) -> v2 == 0 ? null : v1 % v2, + (v1, v2) -> v2 == 0 ? null : v1 % v2, + (v1, v2) -> v2 == 0 ? null : v1 % v2)); + } + + /** + * Definition of pow(x, y)/power(x, y) function. + * Calculate the value of x raised to the power of y + * The supported signature of pow/power function is + * (INTEGER, INTEGER) -> INTEGER + * (LONG, LONG) -> LONG + * (FLOAT, FLOAT) -> FLOAT + * (DOUBLE, DOUBLE) -> DOUBLE + */ + private static FunctionResolver pow() { + FunctionName functionName = BuiltinFunctionName.POW.getName(); + return new FunctionResolver(functionName, doubleArgumentsFunction(functionName, Math::pow)); + } + + private static FunctionResolver power() { + FunctionName functionName = BuiltinFunctionName.POWER.getName(); + return new FunctionResolver(functionName, doubleArgumentsFunction(functionName, Math::pow)); + } + + /** + * Definition of round(x)/round(x, d) function. + * Rounds the argument x to d decimal places, d defaults to 0 if not specified. + * The supported signature of round function is + * (x: INTEGER [, y: INTEGER]) -> INTEGER + * (x: LONG [, y: INTEGER]) -> LONG + * (x: FLOAT [, y: INTEGER]) -> FLOAT + * (x: DOUBLE [, y: INTEGER]) -> DOUBLE + */ + private static FunctionResolver round() { + FunctionName functionName = BuiltinFunctionName.ROUND.getName(); + return new FunctionResolver(functionName, + new ImmutableMap.Builder() + .put( + new FunctionSignature(functionName, Arrays.asList(ExprCoreType.INTEGER)), + unaryOperator( + functionName, v -> (long) Math.round(v), ExprValueUtils::getIntegerValue, + ExprCoreType.LONG)) + .put( + new FunctionSignature(functionName, Arrays.asList(ExprCoreType.LONG)), + unaryOperator( + functionName, v -> (long) Math.round(v), ExprValueUtils::getLongValue, + ExprCoreType.LONG)) + .put( + new FunctionSignature(functionName, Arrays.asList(ExprCoreType.FLOAT)), + unaryOperator( + functionName, v -> (double) Math.round(v), ExprValueUtils::getFloatValue, + ExprCoreType.DOUBLE)) + .put( + new FunctionSignature(functionName, Arrays.asList(ExprCoreType.DOUBLE)), + unaryOperator( + functionName, v -> (double) Math.round(v), ExprValueUtils::getDoubleValue, + ExprCoreType.DOUBLE)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.INTEGER, ExprCoreType.INTEGER)), + doubleArgFunc(functionName, + (v1, v2) -> new BigDecimal(v1).setScale(v2, RoundingMode.HALF_UP).longValue(), + ExprValueUtils::getIntegerValue, ExprValueUtils::getIntegerValue, + ExprCoreType.LONG)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.LONG, ExprCoreType.INTEGER)), + doubleArgFunc(functionName, + (v1, v2) -> new BigDecimal(v1).setScale(v2, RoundingMode.HALF_UP).longValue(), + ExprValueUtils::getLongValue, ExprValueUtils::getIntegerValue, + ExprCoreType.LONG)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.FLOAT, ExprCoreType.INTEGER)), + doubleArgFunc(functionName, + (v1, v2) -> new BigDecimal(v1).setScale(v2, RoundingMode.HALF_UP).doubleValue(), + ExprValueUtils::getFloatValue, ExprValueUtils::getIntegerValue, + ExprCoreType.DOUBLE)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.DOUBLE, ExprCoreType.INTEGER)), + doubleArgFunc(functionName, + (v1, v2) -> new BigDecimal(v1).setScale(v2, RoundingMode.HALF_UP).doubleValue(), + ExprValueUtils::getDoubleValue, ExprValueUtils::getIntegerValue, + ExprCoreType.DOUBLE)) + .build()); + } + + /** + * Definition of sign(x) function. + * Returns the sign of the argument as -1, 0, or 1 + * depending on whether x is negative, zero, or positive + * The supported signature is + * INTEGER/LONG/FLOAT/DOUBLE -> INTEGER + */ + private static FunctionResolver sign() { + FunctionName functionName = BuiltinFunctionName.SIGN.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.DOUBLE)), + unaryOperator( + functionName, v -> (int) Math.signum(v), ExprValueUtils::getDoubleValue, + ExprCoreType.INTEGER)) + .build()); + } + + /** + * Definition of sqrt(x) function. + * Calculate the square root of a non-negative number x + * The supported signature is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static FunctionResolver sqrt() { + FunctionName functionName = BuiltinFunctionName.SQRT.getName(); + return new FunctionResolver( + functionName, + singleArgumentFunction( + functionName, + v -> v < 0 ? null : Math.sqrt(v))); + } + + /** + * Definition of truncate(x, d) function. + * Returns the number x, truncated to d decimal places + * The supported signature of round function is + * (x: INTEGER, y: INTEGER) -> INTEGER + * (x: LONG, y: INTEGER) -> LONG + * (x: FLOAT, y: INTEGER) -> FLOAT + * (x: DOUBLE, y: INTEGER) -> DOUBLE + */ + private static FunctionResolver truncate() { + FunctionName functionName = BuiltinFunctionName.TRUNCATE.getName(); + return new FunctionResolver(functionName, + new ImmutableMap.Builder() + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.INTEGER, ExprCoreType.INTEGER)), + doubleArgFunc(functionName, + (v1, v2) -> new BigDecimal(v1).setScale(v2, RoundingMode.DOWN).longValue(), + ExprValueUtils::getIntegerValue, ExprValueUtils::getIntegerValue, + ExprCoreType.LONG)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.LONG, ExprCoreType.INTEGER)), + doubleArgFunc(functionName, + (v1, v2) -> new BigDecimal(v1).setScale(v2, RoundingMode.DOWN).longValue(), + ExprValueUtils::getLongValue, ExprValueUtils::getIntegerValue, + ExprCoreType.LONG)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.FLOAT, ExprCoreType.INTEGER)), + doubleArgFunc(functionName, + (v1, v2) -> new BigDecimal(v1).setScale(v2, RoundingMode.DOWN).doubleValue(), + ExprValueUtils::getFloatValue, ExprValueUtils::getIntegerValue, + ExprCoreType.DOUBLE)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.DOUBLE, ExprCoreType.INTEGER)), + doubleArgFunc(functionName, + (v1, v2) -> new BigDecimal(v1).setScale(v2, RoundingMode.DOWN).doubleValue(), + ExprValueUtils::getDoubleValue, ExprValueUtils::getIntegerValue, + ExprCoreType.DOUBLE)) + .build()); + } + /** * Util method to generate single argument function bundles. Applicable for INTEGER -> INTEGER * LONG -> LONG FLOAT -> FLOAT DOUBLE -> DOUBLE @@ -221,4 +468,50 @@ private static Map singleArgumentFunction( functionName, doubleFunc, ExprValueUtils::getDoubleValue, ExprCoreType.DOUBLE)) .build(); } + + private static Map doubleArgumentsFunction( + FunctionName functionName, + BiFunction intFunc, + BiFunction longFunc, + BiFunction floatFunc, + BiFunction doubleFunc) { + return new ImmutableMap.Builder() + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.INTEGER, ExprCoreType.INTEGER)), + doubleArgFunc( + functionName, intFunc, ExprValueUtils::getIntegerValue, + ExprValueUtils::getIntegerValue, ExprCoreType.INTEGER)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.LONG, ExprCoreType.LONG)), + doubleArgFunc( + functionName, longFunc, ExprValueUtils::getLongValue, + ExprValueUtils::getLongValue, ExprCoreType.LONG)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.FLOAT, ExprCoreType.FLOAT)), + doubleArgFunc( + functionName, floatFunc, ExprValueUtils::getFloatValue, + ExprValueUtils::getFloatValue, ExprCoreType.FLOAT)) + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.DOUBLE, ExprCoreType.DOUBLE)), + doubleArgFunc( + functionName, doubleFunc, ExprValueUtils::getDoubleValue, + ExprValueUtils::getDoubleValue, ExprCoreType.DOUBLE)) + .build(); + } + + private static Map doubleArgumentsFunction( + FunctionName functionName, + BiFunction doubleFunc) { + return new ImmutableMap.Builder() + .put( + new FunctionSignature( + functionName, Arrays.asList(ExprCoreType.DOUBLE, ExprCoreType.DOUBLE)), + doubleArgFunc( + functionName, doubleFunc, ExprValueUtils::getDoubleValue, + ExprValueUtils::getDoubleValue, ExprCoreType.DOUBLE)).build(); + } } diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java index 002d2f1b19..877042ef32 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java @@ -24,6 +24,7 @@ import static com.amazon.opendistroforelasticsearch.sql.expression.DSL.literal; import static com.amazon.opendistroforelasticsearch.sql.expression.DSL.ref; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; @@ -169,6 +170,12 @@ public void divide(ExprValue op1, ExprValue op2) { assertEquals(expectedType, expression.type()); assertValueEqual(BuiltinFunctionName.DIVIDE, expectedType, op1, op2, expression.valueOf(null)); assertEquals(String.format("%s / %s", op1.toString(), op2.toString()), expression.toString()); + + expression = dsl.divide(literal(op1), literal(0)); + expectedType = WideningTypeRule.max(op1.type(), INTEGER); + assertEquals(expectedType, expression.type()); + assertTrue(expression.valueOf(valueEnv()).isNull()); + assertEquals(String.format("%s / 0", op1.toString()), expression.toString()); } @ParameterizedTest(name = "module({1}, {2})") @@ -179,6 +186,12 @@ public void module(ExprValue op1, ExprValue op2) { assertEquals(expectedType, expression.type()); assertValueEqual(BuiltinFunctionName.MODULES, expectedType, op1, op2, expression.valueOf(null)); assertEquals(op1.toString() + " % " + op2.toString(), expression.toString()); + + expression = dsl.module(literal(op1), literal(0)); + expectedType = WideningTypeRule.max(op1.type(), INTEGER); + assertEquals(expectedType, expression.type()); + assertTrue(expression.valueOf(valueEnv()).isNull()); + assertEquals(op1.toString() + " % 0", expression.toString()); } protected void assertValueEqual(BuiltinFunctionName builtinFunctionName, ExprType type, diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java index 57e6fa5c7b..c4ada14b0d 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java @@ -19,11 +19,14 @@ import static com.amazon.opendistroforelasticsearch.sql.config.TestConfig.DOUBLE_TYPE_NULL_VALUE_FIELD; import static com.amazon.opendistroforelasticsearch.sql.config.TestConfig.INT_TYPE_MISSING_VALUE_FIELD; import static com.amazon.opendistroforelasticsearch.sql.config.TestConfig.INT_TYPE_NULL_VALUE_FIELD; +import static com.amazon.opendistroforelasticsearch.sql.config.TestConfig.STRING_TYPE_MISSING_VALUE_FILED; +import static com.amazon.opendistroforelasticsearch.sql.config.TestConfig.STRING_TYPE_NULL_VALUE_FILED; import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.getDoubleValue; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; import static com.amazon.opendistroforelasticsearch.sql.utils.MatcherUtils.hasType; import static com.amazon.opendistroforelasticsearch.sql.utils.MatcherUtils.hasValue; import static org.hamcrest.MatcherAssert.assertThat; @@ -32,10 +35,18 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; import com.amazon.opendistroforelasticsearch.sql.expression.DSL; import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase; import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression; +import com.google.common.collect.Lists; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.List; +import java.util.stream.Collectors; import java.util.stream.Stream; +import java.util.zip.CRC32; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; @@ -226,6 +237,194 @@ public void ceil_missing_value() { assertTrue(ceiling.valueOf(valueEnv()).isMissing()); } + /** + * Test conv from decimal base with string as a number. + */ + @ParameterizedTest(name = "conv({0})") + @ValueSource(strings = {"1", "0", "-1"}) + public void conv_from_decimal(String value) { + FunctionExpression conv = dsl.conv(DSL.literal(value), DSL.literal(10), DSL.literal(2)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(Integer.parseInt(value), 2)))); + assertEquals(String.format("conv(\"%s\", 10, 2)", value), conv.toString()); + + conv = dsl.conv(DSL.literal(value), DSL.literal(10), DSL.literal(8)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(Integer.parseInt(value), 8)))); + assertEquals(String.format("conv(\"%s\", 10, 8)", value), conv.toString()); + + conv = dsl.conv(DSL.literal(value), DSL.literal(10), DSL.literal(16)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(Integer.parseInt(value), 16)))); + assertEquals(String.format("conv(\"%s\", 10, 16)", value), conv.toString()); + } + + /** + * Test conv from decimal base with integer as a number. + */ + @ParameterizedTest(name = "conv({0})") + @ValueSource(ints = {1, 0, -1}) + public void conv_from_decimal(Integer value) { + FunctionExpression conv = dsl.conv(DSL.literal(value), DSL.literal(10), DSL.literal(2)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(value, 2)))); + assertEquals(String.format("conv(%s, 10, 2)", value), conv.toString()); + + conv = dsl.conv(DSL.literal(value), DSL.literal(10), DSL.literal(8)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(value, 8)))); + assertEquals(String.format("conv(%s, 10, 8)", value), conv.toString()); + + conv = dsl.conv(DSL.literal(value), DSL.literal(10), DSL.literal(16)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(value, 16)))); + assertEquals(String.format("conv(%s, 10, 16)", value), conv.toString()); + } + + /** + * Test conv to decimal base with string as a number. + */ + @ParameterizedTest(name = "conv({0})") + @ValueSource(strings = {"11", "0", "11111"}) + public void conv_to_decimal(String value) { + FunctionExpression conv = dsl.conv(DSL.literal(value), DSL.literal(2), DSL.literal(10)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(Integer.parseInt(value, 2))))); + assertEquals(String.format("conv(\"%s\", 2, 10)", value), conv.toString()); + + conv = dsl.conv(DSL.literal(value), DSL.literal(8), DSL.literal(10)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(Integer.parseInt(value, 8))))); + assertEquals(String.format("conv(\"%s\", 8, 10)", value), conv.toString()); + + conv = dsl.conv(DSL.literal(value), DSL.literal(16), DSL.literal(10)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(Integer.parseInt(value, 16))))); + assertEquals(String.format("conv(\"%s\", 16, 10)", value), conv.toString()); + } + + /** + * Test conv to decimal base with integer as a number. + */ + @ParameterizedTest(name = "conv({0})") + @ValueSource(ints = {11, 0, 11111}) + public void conv_to_decimal(Integer value) { + FunctionExpression conv = dsl.conv(DSL.literal(value), DSL.literal(2), DSL.literal(10)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(Integer.parseInt(value.toString(), 2))))); + assertEquals(String.format("conv(%s, 2, 10)", value), conv.toString()); + + conv = dsl.conv(DSL.literal(value), DSL.literal(8), DSL.literal(10)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(Integer.parseInt(value.toString(), 8))))); + assertEquals(String.format("conv(%s, 8, 10)", value), conv.toString()); + + conv = dsl.conv(DSL.literal(value), DSL.literal(16), DSL.literal(10)); + assertThat( + conv.valueOf(valueEnv()), + allOf(hasType(STRING), hasValue(Integer.toString(Integer.parseInt(value.toString(), 16))))); + assertEquals(String.format("conv(%s, 16, 10)", value), conv.toString()); + } + + /** + * Test conv with null value. + */ + @Test + public void conv_null_value() { + FunctionExpression conv = dsl.conv( + DSL.ref(STRING_TYPE_NULL_VALUE_FILED, STRING), DSL.literal(10), DSL.literal(2)); + assertEquals(STRING, conv.type()); + assertTrue(conv.valueOf(valueEnv()).isNull()); + + conv = dsl.conv( + DSL.literal("1"), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.literal(2)); + assertEquals(STRING, conv.type()); + assertTrue(conv.valueOf(valueEnv()).isNull()); + + conv = dsl.conv( + DSL.literal("1"), DSL.literal(10), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(STRING, conv.type()); + assertTrue(conv.valueOf(valueEnv()).isNull()); + } + + /** + * Test conv with missing value. + */ + @Test + public void conv_missing_value() { + FunctionExpression conv = dsl.conv( + DSL.ref(STRING_TYPE_MISSING_VALUE_FILED, STRING), DSL.literal(10), DSL.literal(2)); + assertEquals(STRING, conv.type()); + assertTrue(conv.valueOf(valueEnv()).isMissing()); + + conv = dsl.conv( + DSL.literal("1"), DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), DSL.literal(2)); + assertEquals(STRING, conv.type()); + assertTrue(conv.valueOf(valueEnv()).isMissing()); + + conv = dsl.conv( + DSL.literal("1"), DSL.literal(10), DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(STRING, conv.type()); + assertTrue(conv.valueOf(valueEnv()).isMissing()); + } + + /** + * Test conv with null and missing values. + */ + @Test + public void conv_null_missing() { + FunctionExpression conv = dsl.conv(DSL.ref(STRING_TYPE_MISSING_VALUE_FILED, STRING), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), DSL.literal(2)); + assertEquals(STRING, conv.type()); + assertTrue(conv.valueOf(valueEnv()).isMissing()); + } + + /** + * Test crc32 with string value. + */ + @ParameterizedTest(name = "crc({0})") + @ValueSource(strings = {"odfe", "sql"}) + public void crc32_string_value(String value) { + FunctionExpression crc = dsl.crc32(DSL.literal(value)); + CRC32 crc32 = new CRC32(); + crc32.update(value.getBytes()); + assertThat( + crc.valueOf(valueEnv()), + allOf(hasType(LONG), hasValue(crc32.getValue()))); + assertEquals(String.format("crc32(\"%s\")", value), crc.toString()); + } + + /** + * Test crc32 with null value. + */ + @Test + public void crc32_null_value() { + FunctionExpression crc = dsl.crc32(DSL.ref(STRING_TYPE_NULL_VALUE_FILED, STRING)); + assertEquals(LONG, crc.type()); + assertTrue(crc.valueOf(valueEnv()).isNull()); + } + + /** + * Test crc32 with missing value. + */ + @Test + public void crc32_missing_value() { + FunctionExpression crc = dsl.crc32(DSL.ref(STRING_TYPE_MISSING_VALUE_FILED, STRING)); + assertEquals(LONG, crc.type()); + assertTrue(crc.valueOf(valueEnv()).isMissing()); + } + /** * Test exp with integer value. */ @@ -803,4 +1002,719 @@ public void log2_missing_value() { assertEquals(DOUBLE, log.type()); assertTrue(log.valueOf(valueEnv()).isMissing()); } + + /** + * Test mod with integer value. + */ + @ParameterizedTest(name = "mod({0}, {1})") + @MethodSource("testLogIntegerArguments") + public void mod_int_value(Integer v1, Integer v2) { + FunctionExpression mod = dsl.mod(DSL.literal(v1), DSL.literal(v2)); + assertThat( + mod.valueOf(valueEnv()), + allOf(hasType(INTEGER), hasValue(v1 % v2))); + assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); + + mod = dsl.mod(DSL.literal(v1), DSL.literal(0)); + assertEquals(INTEGER, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isNull()); + } + + /** + * Test mod with long value. + */ + @ParameterizedTest(name = "mod({0}, {1})") + @MethodSource("testLogLongArguments") + public void mod_long_value(Long v1, Long v2) { + FunctionExpression mod = dsl.mod(DSL.literal(v1), DSL.literal(v2)); + assertThat( + mod.valueOf(valueEnv()), + allOf(hasType(LONG), hasValue(v1 % v2))); + assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); + + mod = dsl.mod(DSL.literal(v1), DSL.literal(0)); + assertEquals(LONG, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isNull()); + } + + /** + * Test mod with long value. + */ + @ParameterizedTest(name = "mod({0}, {1})") + @MethodSource("testLogFloatArguments") + public void mod_float_value(Float v1, Float v2) { + FunctionExpression mod = dsl.mod(DSL.literal(v1), DSL.literal(v2)); + assertThat( + mod.valueOf(valueEnv()), + allOf(hasType(FLOAT), hasValue(v1 % v2))); + assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); + + mod = dsl.mod(DSL.literal(v1), DSL.literal(0)); + assertEquals(FLOAT, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isNull()); + } + + /** + * Test mod with double value. + */ + @ParameterizedTest(name = "mod({0}, {1})") + @MethodSource("testLogDoubleArguments") + public void mod_double_value(Double v1, Double v2) { + FunctionExpression mod = dsl.mod(DSL.literal(v1), DSL.literal(v2)); + assertThat( + mod.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(v1 % v2))); + assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); + + mod = dsl.mod(DSL.literal(v1), DSL.literal(0)); + assertEquals(DOUBLE, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isNull()); + } + + /** + * Test mod with null value. + */ + @Test + public void mod_null_value() { + FunctionExpression mod = dsl.mod(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(INTEGER, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isNull()); + + mod = dsl.mod(DSL.literal(1), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(INTEGER, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isNull()); + + mod = dsl.mod( + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(INTEGER, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isNull()); + } + + /** + * Test mod with missing value. + */ + @Test + public void mod_missing_value() { + FunctionExpression mod = + dsl.mod(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(INTEGER, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isMissing()); + + mod = dsl.mod(DSL.literal(1), DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(INTEGER, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isMissing()); + + mod = dsl.mod( + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(INTEGER, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isMissing()); + } + + /** + * Test mod with null and missing values. + */ + @Test + public void mod_null_missing() { + FunctionExpression mod = dsl.mod(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(INTEGER, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isMissing()); + + mod = dsl.mod(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(INTEGER, mod.type()); + assertTrue(mod.valueOf(valueEnv()).isMissing()); + } + + /** + * Test pow/power with integer value. + */ + @ParameterizedTest(name = "pow({0}, {1}") + @MethodSource("testLogIntegerArguments") + public void pow_int_value(Integer v1, Integer v2) { + FunctionExpression pow = dsl.pow(DSL.literal(v1), DSL.literal(v2)); + assertThat( + pow.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(v1, v2)))); + assertEquals(String.format("pow(%s, %s)", v1, v2), pow.toString()); + + FunctionExpression power = dsl.power(DSL.literal(v1), DSL.literal(v2)); + assertThat( + power.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(v1, v2)))); + assertEquals(String.format("pow(%s, %s)", v1, v2), pow.toString()); + } + + /** + * Test pow/power with long value. + */ + @ParameterizedTest(name = "pow({0}, {1}") + @MethodSource("testLogLongArguments") + public void pow_long_value(Long v1, Long v2) { + FunctionExpression pow = dsl.pow(DSL.literal(v1), DSL.literal(v2)); + assertThat( + pow.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(v1, v2)))); + assertEquals(String.format("pow(%s, %s)", v1, v2), pow.toString()); + + FunctionExpression power = dsl.power(DSL.literal(v1), DSL.literal(v2)); + assertThat( + power.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(v1, v2)))); + assertEquals(String.format("pow(%s, %s)", v1, v2), pow.toString()); + } + + /** + * Test pow/power with float value. + */ + @ParameterizedTest(name = "pow({0}, {1}") + @MethodSource("testLogFloatArguments") + public void pow_float_value(Float v1, Float v2) { + FunctionExpression pow = dsl.pow(DSL.literal(v1), DSL.literal(v2)); + assertThat( + pow.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(v1, v2)))); + assertEquals(String.format("pow(%s, %s)", v1, v2), pow.toString()); + + FunctionExpression power = dsl.power(DSL.literal(v1), DSL.literal(v2)); + assertThat( + power.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(v1, v2)))); + assertEquals(String.format("pow(%s, %s)", v1, v2), pow.toString()); + } + + /** + * Test pow/power with double value. + */ + @ParameterizedTest(name = "pow({0}, {1}") + @MethodSource("testLogDoubleArguments") + public void pow_double_value(Double v1, Double v2) { + FunctionExpression pow = dsl.pow(DSL.literal(v1), DSL.literal(v2)); + assertThat( + pow.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(v1, v2)))); + assertEquals(String.format("pow(%s, %s)", v1, v2), pow.toString()); + + FunctionExpression power = dsl.power(DSL.literal(v1), DSL.literal(v2)); + assertThat( + power.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(v1, v2)))); + assertEquals(String.format("pow(%s, %s)", v1, v2), pow.toString()); + } + + /** + * Test pow/power with null value. + */ + @Test + public void pow_null_value() { + FunctionExpression pow = dsl.pow(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(DOUBLE, pow.type()); + assertTrue(pow.valueOf(valueEnv()).isNull()); + + dsl.pow(DSL.literal(1), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, pow.type()); + assertTrue(pow.valueOf(valueEnv()).isNull()); + + dsl.pow( + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, pow.type()); + assertTrue(pow.valueOf(valueEnv()).isNull()); + + FunctionExpression power = + dsl.power(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(DOUBLE, power.type()); + assertTrue(power.valueOf(valueEnv()).isNull()); + + power = dsl.power(DSL.literal(1), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, power.type()); + assertTrue(power.valueOf(valueEnv()).isNull()); + + power = dsl.power( + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, power.type()); + assertTrue(power.valueOf(valueEnv()).isNull()); + } + + /** + * Test pow/power with missing value. + */ + @Test + public void pow_missing_value() { + FunctionExpression pow = + dsl.pow(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(DOUBLE, pow.type()); + assertTrue(pow.valueOf(valueEnv()).isMissing()); + + dsl.pow(DSL.literal(1), DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, pow.type()); + assertTrue(pow.valueOf(valueEnv()).isMissing()); + + dsl.pow(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, pow.type()); + assertTrue(pow.valueOf(valueEnv()).isMissing()); + + FunctionExpression power = + dsl.power(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(DOUBLE, power.type()); + assertTrue(power.valueOf(valueEnv()).isMissing()); + + power = dsl.power(DSL.literal(1), DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, power.type()); + assertTrue(power.valueOf(valueEnv()).isMissing()); + + power = dsl.power(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, power.type()); + assertTrue(power.valueOf(valueEnv()).isMissing()); + } + + /** + * Test pow/power with null and missing values. + */ + @Test + public void pow_null_missing() { + FunctionExpression pow = dsl.pow( + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, pow.type()); + assertTrue(pow.valueOf(valueEnv()).isMissing()); + + pow = dsl.pow( + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, pow.type()); + assertTrue(pow.valueOf(valueEnv()).isMissing()); + + FunctionExpression power = dsl.power( + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, power.type()); + assertTrue(power.valueOf(valueEnv()).isMissing()); + + power = dsl.power( + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, power.type()); + assertTrue(power.valueOf(valueEnv()).isMissing()); + } + + /** + * Test round with integer value. + */ + @ParameterizedTest(name = "round({0}") + @ValueSource(ints = {21, -21}) + public void round_int_value(Integer value) { + FunctionExpression round = dsl.round(DSL.literal(value)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(LONG), hasValue((long) Math.round(value)))); + assertEquals(String.format("round(%s)", value), round.toString()); + + round = dsl.round(DSL.literal(value), DSL.literal(1)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(LONG), hasValue( + new BigDecimal(value).setScale(1, RoundingMode.HALF_UP).longValue()))); + assertEquals(String.format("round(%s, 1)", value), round.toString()); + + round = dsl.round(DSL.literal(value), DSL.literal(-1)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(LONG), hasValue( + new BigDecimal(value).setScale(-1, RoundingMode.HALF_UP).longValue()))); + assertEquals(String.format("round(%s, -1)", value), round.toString()); + } + + /** + * Test round with long value. + */ + @ParameterizedTest(name = "round({0}") + @ValueSource(longs = {21L, -21L}) + public void round_long_value(Long value) { + FunctionExpression round = dsl.round(DSL.literal(value)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(LONG), hasValue((long) Math.round(value)))); + assertEquals(String.format("round(%s)", value), round.toString()); + + round = dsl.round(DSL.literal(value), DSL.literal(1)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(LONG), hasValue( + new BigDecimal(value).setScale(1, RoundingMode.HALF_UP).longValue()))); + assertEquals(String.format("round(%s, 1)", value), round.toString()); + + round = dsl.round(DSL.literal(value), DSL.literal(-1)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(LONG), hasValue( + new BigDecimal(value).setScale(-1, RoundingMode.HALF_UP).longValue()))); + assertEquals(String.format("round(%s, -1)", value), round.toString()); + } + + /** + * Test round with float value. + */ + @ParameterizedTest(name = "round({0}") + @ValueSource(floats = {21F, -21F}) + public void round_float_value(Float value) { + FunctionExpression round = dsl.round(DSL.literal(value)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue((double) Math.round(value)))); + assertEquals(String.format("round(%s)", value), round.toString()); + + round = dsl.round(DSL.literal(value), DSL.literal(1)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue( + new BigDecimal(value).setScale(1, RoundingMode.HALF_UP).doubleValue()))); + assertEquals(String.format("round(%s, 1)", value), round.toString()); + + round = dsl.round(DSL.literal(value), DSL.literal(-1)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue( + new BigDecimal(value).setScale(-1, RoundingMode.HALF_UP).doubleValue()))); + assertEquals(String.format("round(%s, -1)", value), round.toString()); + } + + /** + * Test round with double value. + */ + @ParameterizedTest(name = "round({0}") + @ValueSource(doubles = {21D, -21D}) + public void round_double_value(Double value) { + FunctionExpression round = dsl.round(DSL.literal(value)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue((double) Math.round(value)))); + assertEquals(String.format("round(%s)", value), round.toString()); + + round = dsl.round(DSL.literal(value), DSL.literal(1)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue( + new BigDecimal(value).setScale(1, RoundingMode.HALF_UP).doubleValue()))); + assertEquals(String.format("round(%s, 1)", value), round.toString()); + + round = dsl.round(DSL.literal(value), DSL.literal(-1)); + assertThat( + round.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue( + new BigDecimal(value).setScale(-1, RoundingMode.HALF_UP).doubleValue()))); + assertEquals(String.format("round(%s, -1)", value), round.toString()); + } + + /** + * Test round with null value. + */ + @Test + public void round_null_value() { + FunctionExpression round = dsl.round(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(LONG, round.type()); + assertTrue(round.valueOf(valueEnv()).isNull()); + + round = dsl.round(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(LONG, round.type()); + assertTrue(round.valueOf(valueEnv()).isNull()); + + round = dsl.round(DSL.literal(1), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(LONG, round.type()); + assertTrue(round.valueOf(valueEnv()).isNull()); + } + + /** + * Test round with null value. + */ + @Test + public void round_missing_value() { + FunctionExpression round = dsl.round(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(LONG, round.type()); + assertTrue(round.valueOf(valueEnv()).isMissing()); + + round = dsl.round(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(LONG, round.type()); + assertTrue(round.valueOf(valueEnv()).isMissing()); + + round = dsl.round(DSL.literal(1), DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(LONG, round.type()); + assertTrue(round.valueOf(valueEnv()).isMissing()); + } + + /** + * Test round with null and missing values. + */ + @Test + public void round_null_missing() { + FunctionExpression round = dsl.round( + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(LONG, round.type()); + assertTrue(round.valueOf(valueEnv()).isMissing()); + + round = dsl.round( + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(LONG, round.type()); + assertTrue(round.valueOf(valueEnv()).isMissing()); + } + + /** + * Test sign with integer value. + */ + @ParameterizedTest(name = "sign({0})") + @ValueSource(ints = {2, -2}) + public void sign_int_value(Integer value) { + FunctionExpression sign = dsl.sign(DSL.literal(value)); + assertThat( + sign.valueOf(valueEnv()), + allOf(hasType(INTEGER), hasValue((int) Math.signum(value)))); + assertEquals(String.format("sign(%s)", value), sign.toString()); + } + + /** + * Test sign with long value. + */ + @ParameterizedTest(name = "sign({0})") + @ValueSource(longs = {2L, -2L}) + public void sign_long_value(Long value) { + FunctionExpression sign = dsl.sign(DSL.literal(value)); + assertThat( + sign.valueOf(valueEnv()), + allOf(hasType(INTEGER), hasValue((int) Math.signum(value)))); + assertEquals(String.format("sign(%s)", value), sign.toString()); + } + + /** + * Test sign with float value. + */ + @ParameterizedTest(name = "sign({0})") + @ValueSource(floats = {2F, -2F}) + public void sign_float_value(Float value) { + FunctionExpression sign = dsl.sign(DSL.literal(value)); + assertThat( + sign.valueOf(valueEnv()), + allOf(hasType(INTEGER), hasValue((int) Math.signum(value)))); + assertEquals(String.format("sign(%s)", value), sign.toString()); + } + + /** + * Test sign with double value. + */ + @ParameterizedTest(name = "sign({0})") + @ValueSource(doubles = {2, -2}) + public void sign_double_value(Double value) { + FunctionExpression sign = dsl.sign(DSL.literal(value)); + assertThat( + sign.valueOf(valueEnv()), + allOf(hasType(INTEGER), hasValue((int) Math.signum(value)))); + assertEquals(String.format("sign(%s)", value), sign.toString()); + } + + /** + * Test sign with null value. + */ + @Test + public void sign_null_value() { + FunctionExpression sign = dsl.sign(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(INTEGER, sign.type()); + assertTrue(sign.valueOf(valueEnv()).isNull()); + } + + /** + * Test sign with missing value. + */ + @Test + public void sign_missing_value() { + FunctionExpression sign = dsl.sign(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(INTEGER, sign.type()); + assertTrue(sign.valueOf(valueEnv()).isMissing()); + } + + /** + * Test sqrt with int value. + */ + @ParameterizedTest(name = "sqrt({0})") + @ValueSource(ints = {1, 2}) + public void sqrt_int_value(Integer value) { + FunctionExpression sqrt = dsl.sqrt(DSL.literal(value)); + assertThat(sqrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.sqrt(value)))); + assertEquals(String.format("sqrt(%s)", value), sqrt.toString()); + } + + /** + * Test sqrt with long value. + */ + @ParameterizedTest(name = "sqrt({0})") + @ValueSource(longs = {1L, 2L}) + public void sqrt_long_value(Long value) { + FunctionExpression sqrt = dsl.sqrt(DSL.literal(value)); + assertThat(sqrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.sqrt(value)))); + assertEquals(String.format("sqrt(%s)", value), sqrt.toString()); + } + + /** + * Test sqrt with float value. + */ + @ParameterizedTest(name = "sqrt({0})") + @ValueSource(floats = {1F, 2F}) + public void sqrt_float_value(Float value) { + FunctionExpression sqrt = dsl.sqrt(DSL.literal(value)); + assertThat(sqrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.sqrt(value)))); + assertEquals(String.format("sqrt(%s)", value), sqrt.toString()); + } + + /** + * Test sqrt with double value. + */ + @ParameterizedTest(name = "sqrt({0})") + @ValueSource(doubles = {1D, 2D}) + public void sqrt_double_value(Double value) { + FunctionExpression sqrt = dsl.sqrt(DSL.literal(value)); + assertThat(sqrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.sqrt(value)))); + assertEquals(String.format("sqrt(%s)", value), sqrt.toString()); + } + + /** + * Test sqrt with negative value. + */ + @ParameterizedTest(name = "sqrt({0})") + @ValueSource(doubles = {-1D, -2D}) + public void sqrt_negative_value(Double value) { + FunctionExpression sqrt = dsl.sqrt(DSL.literal(value)); + assertEquals(DOUBLE, sqrt.type()); + assertTrue(sqrt.valueOf(valueEnv()).isNull()); + } + + /** + * Test sqrt with null value. + */ + @Test + public void sqrt_null_value() { + FunctionExpression sqrt = dsl.sqrt(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, sqrt.type()); + assertTrue(sqrt.valueOf(valueEnv()).isNull()); + } + + /** + * Test sqrt with missing value. + */ + @Test + public void sqrt_missing_value() { + FunctionExpression sqrt = dsl.sqrt(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, sqrt.type()); + assertTrue(sqrt.valueOf(valueEnv()).isMissing()); + } + + /** + * Test truncate with integer value. + */ + @ParameterizedTest(name = "truncate({0}, {1})") + @ValueSource(ints = {2, -2}) + public void truncate_int_value(Integer value) { + FunctionExpression truncate = dsl.truncate(DSL.literal(value), DSL.literal(1)); + assertThat( + truncate.valueOf(valueEnv()), allOf(hasType(LONG), + hasValue(new BigDecimal(value).setScale(1, RoundingMode.DOWN).longValue()))); + assertEquals(String.format("truncate(%s, 1)", value), truncate.toString()); + } + + /** + * Test truncate with long value. + */ + @ParameterizedTest(name = "truncate({0}, {1})") + @ValueSource(longs = {2L, -2L}) + public void truncate_long_value(Long value) { + FunctionExpression truncate = dsl.truncate(DSL.literal(value), DSL.literal(1)); + assertThat( + truncate.valueOf(valueEnv()), allOf(hasType(LONG), + hasValue(new BigDecimal(value).setScale(1, RoundingMode.DOWN).longValue()))); + assertEquals(String.format("truncate(%s, 1)", value), truncate.toString()); + } + + /** + * Test truncate with float value. + */ + @ParameterizedTest(name = "truncate({0}, {1})") + @ValueSource(floats = {2F, -2F}) + public void truncate_float_value(Float value) { + FunctionExpression truncate = dsl.truncate(DSL.literal(value), DSL.literal(1)); + assertThat( + truncate.valueOf(valueEnv()), allOf(hasType(DOUBLE), + hasValue(new BigDecimal(value).setScale(1, RoundingMode.DOWN).doubleValue()))); + assertEquals(String.format("truncate(%s, 1)", value), truncate.toString()); + } + + /** + * Test truncate with double value. + */ + @ParameterizedTest(name = "truncate({0}, {1})") + @ValueSource(doubles = {2D, -2D}) + public void truncate_double_value(Double value) { + FunctionExpression truncate = dsl.truncate(DSL.literal(value), DSL.literal(1)); + assertThat( + truncate.valueOf(valueEnv()), allOf(hasType(DOUBLE), + hasValue(new BigDecimal(value).setScale(1, RoundingMode.DOWN).doubleValue()))); + assertEquals(String.format("truncate(%s, 1)", value), truncate.toString()); + } + + /** + * Test truncate with null value. + */ + @Test + public void truncate_null_value() { + FunctionExpression truncate = + dsl.truncate(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(LONG, truncate.type()); + assertTrue(truncate.valueOf(valueEnv()).isNull()); + + truncate = dsl.truncate(DSL.literal(1), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(LONG, truncate.type()); + assertTrue(truncate.valueOf(valueEnv()).isNull()); + + truncate = dsl.truncate( + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(LONG, truncate.type()); + assertTrue(truncate.valueOf(valueEnv()).isNull()); + } + + /** + * Test truncate with missing value. + */ + @Test + public void truncate_missing_value() { + FunctionExpression truncate = + dsl.truncate(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), DSL.literal(1)); + assertEquals(LONG, truncate.type()); + assertTrue(truncate.valueOf(valueEnv()).isMissing()); + + truncate = dsl.truncate(DSL.literal(1), DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(LONG, truncate.type()); + assertTrue(truncate.valueOf(valueEnv()).isMissing()); + + truncate = dsl.truncate( + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(LONG, truncate.type()); + assertTrue(truncate.valueOf(valueEnv()).isMissing()); + } + + /** + * Test truncate with null and missing values. + */ + @Test + public void truncate_null_missing() { + FunctionExpression truncate = dsl.truncate(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(LONG, truncate.type()); + assertTrue(truncate.valueOf(valueEnv()).isMissing()); + + truncate = dsl.truncate(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), + DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(LONG, truncate.type()); + assertTrue(truncate.valueOf(valueEnv()).isMissing()); + } } diff --git a/docs/category.json b/docs/category.json index 0bf8fce669..130f2dfe4b 100644 --- a/docs/category.json +++ b/docs/category.json @@ -16,6 +16,7 @@ ], "sql_cli": [ "user/dql/expressions.rst", + "user/dql/functions.rst", "user/beyond/partiql.rst" ] } \ No newline at end of file diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index d048cd10fb..2904133b5a 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -139,6 +139,29 @@ Description Specification is undefined and type check is skipped for now + +CONV +==== + +Description +----------- + +Usage: CONV(x, a, b) converts the number x from a base to b base. + +Argument type: x: STRING, a: INTEGER, b: INTEGER + +Return type: STRING + +Example:: + + od> SELECT CONV('12', 10, 16), CONV('2C', 16, 10), CONV(12, 10, 2), CONV(1111, 2, 10) + fetched rows / total rows = 1/1 + +----------------------+----------------------+-------------------+---------------------+ + | conv("12", 10, 16) | conv("2C", 16, 10) | conv(12, 10, 2) | conv(1111, 2, 10) | + |----------------------+----------------------+-------------------+---------------------| + | c | 44 | 1100 | 15 | + +----------------------+----------------------+-------------------+---------------------+ + COS === @@ -172,6 +195,29 @@ Specifications: 1. COT(NUMBER T) -> DOUBLE +CRC32 +===== + +Description +----------- + +Usage: Calculates a cyclic redundancy check value and returns a 32-bit unsigned value. + +Argument type: STRING + +Return type: LONG + +Example:: + + od> SELECT CRC32('MySQL') + fetched rows / total rows = 1/1 + +------------------+ + | crc32("MySQL") | + |------------------| + | 3259397556 | + +------------------+ + + CURDATE ======= @@ -429,15 +475,27 @@ Specifications: 1. MAKETIME(INTEGER, INTEGER, INTEGER) -> DATE -MODULUS +MOD ======= Description ----------- -Specifications: +Usage: MOD(n, m) calculates the remainder of the number n divided by m. -1. MODULUS(NUMBER T, NUMBER) -> T +Argument type: INTEGER/LONG/FLOAT/DOUBLE + +Return type: Wider type between types of n and m if m is nonzero value. If m equals to 0, then returns NULL. + +Example:: + + od> SELECT MOD(3, 2), MOD(3.1, 2) + fetched rows / total rows = 1/1 + +-------------+---------------+ + | mod(3, 2) | mod(3.1, 2) | + |-------------+---------------| + | 1 | 1.1 | + +-------------+---------------+ MONTH @@ -501,10 +559,21 @@ POW Description ----------- -Specifications: +Usage: POW(x, y) calculates the value of x raised to the power of y. Bad inputs return NULL result. + +Argument type: INTEGER/LONG/FLOAT/DOUBLE + +Return type: DOUBLE + +Example:: -1. POW(NUMBER T) -> T -2. POW(NUMBER T, NUMBER) -> T + od> SELECT POW(3, 2), POW(-3, 2), POW(3, -2) + fetched rows / total rows = 1/1 + +-------------+--------------+--------------------+ + | pow(3, 2) | pow(-3, 2) | pow(3, -2) | + |-------------+--------------+--------------------| + | 9 | 9 | 0.1111111111111111 | + +-------------+--------------+--------------------+ POWER @@ -513,10 +582,21 @@ POWER Description ----------- -Specifications: +Usage: POWER(x, y) calculates the value of x raised to the power of y. Bad inputs return NULL result. + +Argument type: INTEGER/LONG/FLOAT/DOUBLE + +Return type: DOUBLE + +Example:: -1. POWER(NUMBER T) -> T -2. POWER(NUMBER T, NUMBER) -> T + od> SELECT POWER(3, 2), POWER(-3, 2), POWER(3, -2) + fetched rows / total rows = 1/1 + +---------------+----------------+--------------------+ + | power(3, 2) | power(-3, 2) | power(3, -2) | + |---------------+----------------+--------------------| + | 9 | 9 | 0.1111111111111111 | + +---------------+----------------+--------------------+ RADIANS @@ -581,9 +661,24 @@ ROUND Description ----------- -Specifications: +Usage: ROUND(x, d) rounds the argument x to d decimal places, d defaults to 0 if not specified + +Argument type: INTEGER/LONG/FLOAT/DOUBLE + +Return type map: + +(INTEGER/LONG [,INTEGER]) -> LONG +(FLOAT/DOUBLE [,INTEGER]) -> LONG -1. ROUND(NUMBER T) -> T +Example:: + + od> SELECT ROUND(12.34), ROUND(12.34, 1), ROUND(12.34, -1), ROUND(12, 1) + fetched rows / total rows = 1/1 + +----------------+-------------------+--------------------+----------------+ + | round(12.34) | round(12.34, 1) | round(12.34, -1) | round(12, 1) | + |----------------+-------------------+--------------------+----------------| + | 12 | 12.3 | 10 | 12 | + +----------------+-------------------+--------------------+----------------+ RTRIM @@ -603,9 +698,21 @@ SIGN Description ----------- -Specifications: +Usage: Returns the sign of the argument as -1, 0, or 1, depending on whether the number is negative, zero, or positive + +Argument type: INTEGER/LONG/FLOAT/DOUBLE + +Return type: INTEGER -1. SIGN(NUMBER T) -> T +Example:: + + od> SELECT SIGN(1), SIGN(0), SIGN(-1.1) + fetched rows / total rows = 1/1 + +-----------+-----------+--------------+ + | sign(1) | sign(0) | sign(-1.1) | + |-----------+-----------+--------------| + | 1 | 0 | -1 | + +-----------+-----------+--------------+ SIGNUM @@ -647,9 +754,24 @@ SQRT Description ----------- -Specifications: +Usage: Calculates the square root of a non-negative number + +Argument type: INTEGER/LONG/FLOAT/DOUBLE -1. SQRT(NUMBER T) -> T +Return type map: + +(Non-negative) INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE +(Negative) INTEGER/LONG/FLOAT/DOUBLE -> NULL + +Example:: + + od> SELECT SQRT(4), SQRT(4.41) + fetched rows / total rows = 1/1 + +-----------+--------------+ + | sqrt(4) | sqrt(4.41) | + |-----------+--------------| + | 2 | 2.1 | + +-----------+--------------+ SUBSTRING @@ -707,6 +829,31 @@ Specifications: 1. TRIM(STRING T) -> T +TRUNCATE +======== + +Description +----------- + +Usage: TRUNCATE(x, d) returns the number x, truncated to d decimal place + +Argument type: INTEGER/LONG/FLOAT/DOUBLE + +Return type map: + +INTEGER/LONG -> LONG +FLOAT/DOUBLE -> DOUBLE + +Example:: + + fetched rows / total rows = 1/1 + +----------------------+-----------------------+-------------------+ + | truncate(56.78, 1) | truncate(56.78, -1) | truncate(56, 1) | + |----------------------+-----------------------+-------------------| + | 56.7 | 50 | 56 | + +----------------------+-----------------------+-------------------+ + + UPPER ===== diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/MathematicalFunctionIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/MathematicalFunctionIT.java index b462ce6a2f..9c9bf42293 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/MathematicalFunctionIT.java @@ -160,4 +160,121 @@ public void testLog2() throws IOException { closeTo(Math.log(36) / Math.log(2)), closeTo(Math.log(39) / Math.log(2)), closeTo(Math.log(34) / Math.log(2))); } + + @Test + public void testConv() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = conv(age, 10, 16) | fields f", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "string")); + verifyDataRows( + result, rows("20"), rows("24"), rows("1c"), rows("21"), + rows("24"), rows("27"), rows("22")); + } + + @Test + public void testCrc32() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = crc32(firstname) | fields f", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "long")); + verifyDataRows( + result, rows(324249283), rows(3369714977L), rows(1165568529), rows(2293694493L), + rows(3936131563L), rows(256963594), rows(824319315)); + } + + @Test + public void testMod() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = mod(age, 10) | fields f", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "integer")); + verifyDataRows( + result, rows(2), rows(6), rows(8), rows(3), rows(6), rows(9), rows(4)); + } + + @Test + public void testPow() throws IOException { + JSONObject pow = + executeQuery( + String.format( + "source=%s | eval f = pow(age, 2) | fields f", TEST_INDEX_BANK)); + verifySchema(pow, schema("f", null, "double")); + verifyDataRows( + pow, rows(1024), rows(1296), rows(784), rows(1089), rows(1296), rows(1521), rows(1156)); + + JSONObject power = + executeQuery( + String.format( + "source=%s | eval f = power(age, 2) | fields f", TEST_INDEX_BANK)); + verifySchema(power, schema("f", null, "double")); + verifyDataRows( + power, rows(1024), rows(1296), rows(784), rows(1089), rows(1296), rows(1521), rows(1156)); + + } + + @Test + public void testRound() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = round(age) | fields f", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "long")); + verifyDataRows(result, + rows(32), rows(36), rows(28), rows(33), rows(36), rows(39), rows(34)); + + result = + executeQuery( + String.format( + "source=%s | eval f = round(age, -1) | fields f", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "long")); + verifyDataRows(result, + rows(30), rows(40), rows(30), rows(30), rows(40), rows(40), rows(30)); + } + + @Test + public void testSign() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = sign(age) | fields f", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "integer")); + verifyDataRows( + result, rows(1), rows(1), rows(1), rows(1), rows(1), rows(1), rows(1)); + } + + @Test + public void testSqrt() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = sqrt(age) | fields f", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "double")); + verifyDataRows(result, + rows(5.656854249492381), rows(6), rows(5.291502622129181), + rows(5.744562646538029), rows(6), rows(6.244997998398398), + rows(5.830951894845301)); + } + + @Test + public void testTruncate() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | eval f = truncate(age, 1) | fields f", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "long")); + verifyDataRows(result, + rows(32), rows(36), rows(28), rows(33), rows(36), rows(39), rows(34)); + + result = + executeQuery( + String.format( + "source=%s | eval f = truncate(age, -1) | fields f", TEST_INDEX_BANK)); + verifySchema(result, schema("f", null, "long")); + verifyDataRows(result, + rows(30), rows(30), rows(20), rows(30), rows(30), rows(30), rows(30)); + } } diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/ExpressionIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/ExpressionIT.java index 46a0a702a4..754a274f50 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/ExpressionIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/ExpressionIT.java @@ -29,6 +29,7 @@ import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -38,6 +39,7 @@ * and function expression. Since comparison test in {@link SQLCorrectnessIT} is enforced, * this kind of manual written IT class will be focused on anomaly case test. */ +@Ignore public class ExpressionIT extends RestIntegTestCase { @Rule @@ -49,14 +51,6 @@ protected void init() throws Exception { TestUtils.enableNewQueryEngine(client()); } - @Test - public void testDivideZeroExpression() throws Exception { - expectResponseException().hasStatusCode(500) //TODO: should be client error code 400? - .containsMessage("\"reason\": \"/ by zero\"") - .containsMessage("\"type\": \"ArithmeticException\"") - .whenExecute("SELECT 5 / (1 - 1)"); - } - public ResponseExceptionAssertion expectResponseException() { return new ResponseExceptionAssertion(exceptionRule); } diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/MathematicalFunctionIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/MathematicalFunctionIT.java new file mode 100644 index 0000000000..405d9058b5 --- /dev/null +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/MathematicalFunctionIT.java @@ -0,0 +1,143 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.sql.sql; + +import static com.amazon.opendistroforelasticsearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT; +import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.rows; +import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.schema; +import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.verifyDataRows; +import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.verifySchema; +import static com.amazon.opendistroforelasticsearch.sql.util.TestUtils.getResponseBody; + +import com.amazon.opendistroforelasticsearch.sql.legacy.SQLIntegTestCase; +import com.amazon.opendistroforelasticsearch.sql.util.TestUtils; +import java.io.IOException; +import java.util.Locale; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.Response; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; + +public class MathematicalFunctionIT extends SQLIntegTestCase { + + @Override + public void init() throws Exception { + super.init(); + TestUtils.enableNewQueryEngine(client()); + } + + @Test + public void testConv() throws IOException { + JSONObject result = executeQuery("select conv(11, 10, 16)"); + verifySchema(result, schema("conv(11, 10, 16)", null, "string")); + verifyDataRows(result, rows("b")); + + result = executeQuery("select conv(11, 16, 10)"); + verifySchema(result, schema("conv(11, 16, 10)", null, "string")); + verifyDataRows(result, rows("17")); + } + + @Test + public void testCrc32() throws IOException { + JSONObject result = executeQuery("select crc32('MySQL')"); + verifySchema(result, schema("crc32(\"MySQL\")", null, "long")); + verifyDataRows(result, rows(3259397556L)); + } + + @Test + public void testMod() throws IOException { + JSONObject result = executeQuery("select mod(3, 2)"); + verifySchema(result, schema("mod(3, 2)", null, "integer")); + verifyDataRows(result, rows(1)); + + result = executeQuery("select mod(3.1, 2)"); + verifySchema(result, schema("mod(3.1, 2)", null, "double")); + verifyDataRows(result, rows(1.1)); + } + + @Test + public void testRound() throws IOException { + JSONObject result = executeQuery("select round(56.78)"); + verifySchema(result, schema("round(56.78)", null, "double")); + verifyDataRows(result, rows(57)); + + result = executeQuery("select round(56.78, 1)"); + verifySchema(result, schema("round(56.78, 1)", null, "double")); + verifyDataRows(result, rows(56.8)); + + result = executeQuery("select round(56.78, -1)"); + verifySchema(result, schema("round(56.78, -1)", null, "double")); + verifyDataRows(result, rows(60)); + + result = executeQuery("select round(-56)"); + verifySchema(result, schema("round(-56)", null, "long")); + verifyDataRows(result, rows(-56)); + + result = executeQuery("select round(-56, 1)"); + verifySchema(result, schema("round(-56, 1)", null, "long")); + verifyDataRows(result, rows(-56)); + + result = executeQuery("select round(-56, -1)"); + verifySchema(result, schema("round(-56, -1)", null, "long")); + verifyDataRows(result, rows(-60)); + } + + /** + * Test sign function with double value. + */ + @Test + public void testSign() throws IOException { + JSONObject result = executeQuery("select sign(1.1)"); + verifySchema(result, schema("sign(1.1)", null, "integer")); + verifyDataRows(result, rows(1)); + + result = executeQuery("select sign(-1.1)"); + verifySchema(result, schema("sign(-1.1)", null, "integer")); + verifyDataRows(result, rows(-1)); + } + + @Test + public void testTruncate() throws IOException { + JSONObject result = executeQuery("select truncate(56.78, 1)"); + verifySchema(result, schema("truncate(56.78, 1)", null, "double")); + verifyDataRows(result, rows(56.7)); + + result = executeQuery("select truncate(56.78, -1)"); + verifySchema(result, schema("truncate(56.78, -1)", null, "double")); + verifyDataRows(result, rows(50)); + + result = executeQuery("select truncate(-56, 1)"); + verifySchema(result, schema("truncate(-56, 1)", null, "long")); + verifyDataRows(result, rows(-56)); + + result = executeQuery("select truncate(-56, -1)"); + verifySchema(result, schema("truncate(-56, -1)", null, "long")); + verifyDataRows(result, rows(-50)); + } + + protected JSONObject executeQuery(String query) throws IOException { + Request request = new Request("POST", QUERY_API_ENDPOINT); + request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query)); + + RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder(); + restOptionsBuilder.addHeader("Content-Type", "application/json"); + request.setOptions(restOptionsBuilder); + + Response response = client().performRequest(request); + return new JSONObject(getResponseBody(response)); + } +} diff --git a/integ-test/src/test/resources/correctness/expressions/functions.txt b/integ-test/src/test/resources/correctness/expressions/functions.txt index ac253b12fd..9e7090f444 100644 --- a/integ-test/src/test/resources/correctness/expressions/functions.txt +++ b/integ-test/src/test/resources/correctness/expressions/functions.txt @@ -10,6 +10,19 @@ ceil(-1) ceil(0.0) ceil(0.4999) ceil(abs(1)) +power(2, 2) +power(2, -2) +power(2.1, 2) +power(2, -2.1) +power(abs(2), 2) +sign(0) +sign(-1) +sign(1) +sign(abs(1)) +sqrt(0) +sqrt(1) +sqrt(1.1) +sqrt(abs(1)) exp(0) exp(1) exp(-1) @@ -24,4 +37,4 @@ log(2.1) log(log(2)) log10(2) log10(2.1) -log10(log10(2)) \ No newline at end of file +log10(log10(2)) diff --git a/ppl/src/main/antlr/OpenDistroPPLLexer.g4 b/ppl/src/main/antlr/OpenDistroPPLLexer.g4 index 68e0061827..2d34e9c49f 100644 --- a/ppl/src/main/antlr/OpenDistroPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenDistroPPLLexer.g4 @@ -137,12 +137,21 @@ DC: 'DC'; ABS: 'ABS'; CEIL: 'CEIL'; CEILING: 'CEILING'; +CONV: 'CONV'; +CRC32: 'CRC32'; EXP: 'EXP'; FLOOR: 'FLOOR'; LN: 'LN'; LOG: 'LOG'; LOG10: 'LOG10'; LOG2: 'LOG2'; +MOD: 'MOD'; +POW: 'POW'; +POWER: 'POWER'; +ROUND: 'ROUND'; +SIGN: 'SIGN'; +SQRT: 'SQRT'; +TRUNCATE: 'TRUNCATE'; // LITERALS AND VALUES //STRING_LITERAL: DQUOTA_STRING | SQUOTA_STRING | BQUOTA_STRING; @@ -150,7 +159,7 @@ ID: ID_LITERAL; INTEGER_LITERAL: DEC_DIGIT+; DECIMAL_LITERAL: (DEC_DIGIT+)? '.' DEC_DIGIT+; -fragment ID_LITERAL: [A-Z_$0-9@]*?[A-Z_$\-]+?[A-Z_$\-0-9]*; +fragment ID_LITERAL: [A-Z_]+[A-Z_$0-9@\-]*; DQUOTA_STRING: '"' ( '\\'. | '""' | ~('"'| '\\') )* '"'; SQUOTA_STRING: '\'' ('\\'. | '\'\'' | ~('\'' | '\\'))* '\''; BQUOTA_STRING: '`' ( '\\'. | '``' | ~('`'|'\\'))* '`'; diff --git a/ppl/src/main/antlr/OpenDistroPPLParser.g4 b/ppl/src/main/antlr/OpenDistroPPLParser.g4 index a4b5abc412..7b393e9a37 100644 --- a/ppl/src/main/antlr/OpenDistroPPLParser.g4 +++ b/ppl/src/main/antlr/OpenDistroPPLParser.g4 @@ -203,7 +203,8 @@ functionArg ; mathematicalFunctionBase - : ABS | CEIL | CEILING | EXP | FLOOR | LN | LOG | LOG10 | LOG2 + : ABS | CEIL | CEILING | CONV | CRC32 | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | POW | POWER + | ROUND | SIGN | SQRT | TRUNCATE ; dateAndTimeFunctionBase @@ -226,8 +227,8 @@ binaryOperator /** literals and values*/ literalValue : stringLiteral - | (PLUS | MINUS)? integerLiteral - | (PLUS | MINUS)? decimalLiteral + | integerLiteral + | decimalLiteral | booleanLiteral ; @@ -236,11 +237,11 @@ stringLiteral ; integerLiteral - : INTEGER_LITERAL + : (PLUS | MINUS)? INTEGER_LITERAL ; decimalLiteral - : DECIMAL_LITERAL + : (PLUS | MINUS)? DECIMAL_LITERAL ; booleanLiteral diff --git a/sql/src/main/antlr/OpenDistroSQLLexer.g4 b/sql/src/main/antlr/OpenDistroSQLLexer.g4 index cb049998e4..5a511d2aa6 100644 --- a/sql/src/main/antlr/OpenDistroSQLLexer.g4 +++ b/sql/src/main/antlr/OpenDistroSQLLexer.g4 @@ -142,9 +142,11 @@ CEIL: 'CEIL'; CEILING: 'CEILING'; CONCAT: 'CONCAT'; CONCAT_WS: 'CONCAT_WS'; +CONV: 'CONV'; COS: 'COS'; COSH: 'COSH'; COT: 'COT'; +CRC32: 'CRC32'; CURDATE: 'CURDATE'; DATE: 'DATE'; DATE_FORMAT: 'DATE_FORMAT'; @@ -189,6 +191,7 @@ SUBTRACT: 'SUBTRACT'; TAN: 'TAN'; TIME: 'TIME'; TIMESTAMP: 'TIMESTAMP'; +TRUNCATE: 'TRUNCATE'; UPPER: 'UPPER'; D: 'D'; diff --git a/sql/src/main/antlr/OpenDistroSQLParser.g4 b/sql/src/main/antlr/OpenDistroSQLParser.g4 index 0644bfd404..264576ee9d 100644 --- a/sql/src/main/antlr/OpenDistroSQLParser.g4 +++ b/sql/src/main/antlr/OpenDistroSQLParser.g4 @@ -161,10 +161,15 @@ functionCall ; scalarFunctionName - : ABS | CEIL | CEILING | EXP | FLOOR | LN | LOG | LOG10 | LOG2 + : mathematicalFunctionName | dateTimeFunctionName ; +mathematicalFunctionName + : ABS | CEIL | CEILING | CONV | CRC32 | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | POW | POWER + | ROUND | SIGN | SQRT | TRUNCATE + ; + dateTimeFunctionName : DAYOFMONTH ;