Skip to content

Commit

Permalink
Add CBRT to the V2 engine (#166)
Browse files Browse the repository at this point in the history
Signed-off-by: Margarit Hakobyan <margarith@bitquilltech.com>
  • Loading branch information
margarit-h committed Nov 17, 2022
1 parent 662a938 commit eb2435a
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 3 deletions.
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ public FunctionExpression sqrt(Expression... expressions) {
return function(BuiltinFunctionName.SQRT, expressions);
}

public FunctionExpression cbrt(Expression... expressions) {
return function(BuiltinFunctionName.CBRT, expressions);
}

public FunctionExpression truncate(Expression... expressions) {
return function(BuiltinFunctionName.TRUNCATE, expressions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public enum BuiltinFunctionName {
ROUND(FunctionName.of("round")),
SIGN(FunctionName.of("sign")),
SQRT(FunctionName.of("sqrt")),
CBRT(FunctionName.of("cbrt")),
TRUNCATE(FunctionName.of("truncate")),

ACOS(FunctionName.of("acos")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class MathematicalFunction {
*/
public static void register(BuiltinFunctionRepository repository) {
repository.register(abs());
repository.register(cbrt());
repository.register(ceil());
repository.register(ceiling());
repository.register(conv());
Expand Down Expand Up @@ -471,6 +472,20 @@ private static DefaultFunctionResolver sqrt() {
DOUBLE, type)).collect(Collectors.toList()));
}

/**
* Definition of cbrt(x) function.
* Calculate the cube root of a number x
* The supported signature is
* INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE
*/
private static DefaultFunctionResolver cbrt() {
return FunctionDSL.define(BuiltinFunctionName.CBRT.getName(),
ExprCoreType.numberTypes().stream()
.map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling(
v -> new ExprDoubleValue(Math.cbrt(v.doubleValue()))),
DOUBLE, type)).collect(Collectors.toList()));
}

/**
* Definition of truncate(x, d) function.
* Returns the number x, truncated to d decimal places
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2327,4 +2327,79 @@ public void tan_missing_value() {
assertEquals(DOUBLE, tan.type());
assertTrue(tan.valueOf(valueEnv()).isMissing());
}

/**
* Test cbrt with int value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(ints = {1, 2})
public void cbrt_int_value(Integer value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with long value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(longs = {1L, 2L})
public void cbrt_long_value(Long value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with float value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(floats = {1F, 2F})
public void cbrt_float_value(Float value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with double value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(doubles = {1D, 2D, Double.MAX_VALUE, Double.MIN_VALUE})
public void cbrt_double_value(Double value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with negative value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(doubles = {-1D, -2D})
public void cbrt_negative_value(Double value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with null value.
*/
@Test
public void cbrt_null_value() {
FunctionExpression cbrt = dsl.cbrt(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER));
assertEquals(DOUBLE, cbrt.type());
assertTrue(cbrt.valueOf(valueEnv()).isNull());
}

/**
* Test cbrt with missing value.
*/
@Test
public void cbrt_missing_value() {
FunctionExpression cbrt = dsl.cbrt(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER));
assertEquals(DOUBLE, cbrt.type());
assertTrue(cbrt.valueOf(valueEnv()).isMissing());
}
}
19 changes: 17 additions & 2 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,24 @@ CBRT
Description
>>>>>>>>>>>

Specifications:
Usage: CBRT(number) calculates the cube root of a number

Argument type: INTEGER/LONG/FLOAT/DOUBLE

Return type: DOUBLE

1. CBRT(NUMBER T) -> T
(Non-negative) INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE
(Negative) INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE

Example::

opensearchsql> SELECT CBRT(8), CBRT(9.261), CBRT(-27);
fetched rows / total rows = 1/1
+-----------+---------------+-------------+
| CBRT(8) | CBRT(9.261) | CBRT(-27) |
|-----------+---------------+-------------|
| 2.0 | 2.1 | -3.0 |
+-----------+---------------+-------------+


CEIL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,20 @@ protected JSONObject executeQuery(String query) throws IOException {
Response response = client().performRequest(request);
return new JSONObject(getResponseBody(response));
}


@Test
public void testCbrt() throws IOException {
JSONObject result = executeQuery("select cbrt(8)");
verifySchema(result, schema("cbrt(8)", "double"));
verifyDataRows(result, rows(2.0));

result = executeQuery("select cbrt(9.261)");
verifySchema(result, schema("cbrt(9.261)", "double"));
verifyDataRows(result, rows(2.1));

result = executeQuery("select cbrt(-27)");
verifySchema(result, schema("cbrt(-27)", "double"));
verifyDataRows(result, rows(-3.0));
}
}
2 changes: 1 addition & 1 deletion sql/src/main/antlr/OpenSearchSQLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ aggregationFunctionName
;

mathematicalFunctionName
: ABS | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER
: ABS | CBRT | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER
| RAND | ROUND | SIGN | SQRT | TRUNCATE
| trigonometricFunctionName
;
Expand Down

0 comments on commit eb2435a

Please sign in to comment.