From fb98ea54ffa4c7ec5f373a6d9ed56eb239c9fac1 Mon Sep 17 00:00:00 2001 From: Harold Wang <74381974+harold-wang@users.noreply.github.com> Date: Tue, 15 Dec 2020 10:10:00 -0800 Subject: [PATCH] Enable Date type input in function Count() (#931) * Enable count(Date) Add IT Add Comparsion Test * Enable count(Date) Add IT * Add comparsion test file 916.txt * Consolidate count function to accept all field type --- .../aggregation/AggregatorFunction.java | 34 +++++++------------ .../aggregation/CountAggregatorTest.java | 22 ++++++++++++ .../resources/correctness/bugfixes/916.txt | 1 + .../correctness/queries/aggregation.txt | 1 + 4 files changed, 37 insertions(+), 21 deletions(-) create mode 100644 integ-test/src/test/resources/correctness/bugfixes/916.txt diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java index a09a2c6833..e467c38585 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java @@ -28,6 +28,8 @@ import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIME; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP; +import com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType; +import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository; import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionBuilder; @@ -35,7 +37,13 @@ import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionResolver; import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionSignature; import com.google.common.collect.ImmutableMap; + +import java.util.ArrayList; import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.stream.Collectors; + import lombok.experimental.UtilityClass; /** @@ -73,27 +81,11 @@ private static FunctionResolver avg() { private static FunctionResolver count() { FunctionName functionName = BuiltinFunctionName.COUNT.getName(); - return new FunctionResolver( - functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - arguments -> new CountAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - arguments -> new CountAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - arguments -> new CountAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new CountAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(STRING)), - arguments -> new CountAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(STRUCT)), - arguments -> new CountAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(ARRAY)), - arguments -> new CountAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(BOOLEAN)), - arguments -> new CountAggregator(arguments, INTEGER)) - .build() - ); + FunctionResolver functionResolver = new FunctionResolver(functionName, + ExprCoreType.coreTypes().stream().collect(Collectors.toMap( + type -> new FunctionSignature(functionName, Collections.singletonList(type)), + type -> arguments -> new CountAggregator(arguments, INTEGER)))); + return functionResolver; } private static FunctionResolver sum() { diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregatorTest.java index 4f42bec8fa..1190cc01df 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregatorTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregatorTest.java @@ -17,12 +17,16 @@ import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.ARRAY; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRUCT; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP; + import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -58,6 +62,24 @@ public void count_double_field_expression() { assertEquals(4, result.value()); } + @Test + public void count_date_field_expression() { + ExprValue result = aggregation(dsl.count(DSL.ref("date_value", DATE)), tuples); + assertEquals(4, result.value()); + } + + @Test + public void count_timestamp_field_expression() { + ExprValue result = aggregation(dsl.count(DSL.ref("timestamp_value", TIMESTAMP)), tuples); + assertEquals(4, result.value()); + } + + @Test + public void count_datetime_field_expression() { + ExprValue result = aggregation(dsl.count(DSL.ref("datetime_value", DATETIME)), tuples); + assertEquals(4, result.value()); + } + @Test public void count_arithmetic_expression() { ExprValue result = aggregation(dsl.count( diff --git a/integ-test/src/test/resources/correctness/bugfixes/916.txt b/integ-test/src/test/resources/correctness/bugfixes/916.txt new file mode 100644 index 0000000000..44a715f0eb --- /dev/null +++ b/integ-test/src/test/resources/correctness/bugfixes/916.txt @@ -0,0 +1 @@ +SELECT COUNT(timestamp) FROM kibana_sample_data_flights diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index 3a2081d9a8..e3878be86c 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -1,4 +1,5 @@ SELECT COUNT(AvgTicketPrice) FROM kibana_sample_data_flights +SELECT count(timestamp) from kibana_sample_data_flights SELECT AVG(AvgTicketPrice) FROM kibana_sample_data_flights SELECT SUM(AvgTicketPrice) FROM kibana_sample_data_flights SELECT MAX(AvgTicketPrice) FROM kibana_sample_data_flights