From 19a9bb8a516ced26c8528f50b5d3f4cfa9d245a1 Mon Sep 17 00:00:00 2001 From: Avinash Jain Date: Thu, 10 Nov 2022 15:50:48 +0000 Subject: [PATCH] Adding support of ARRAY_SORT_DESC for Presto --- .../src/main/sphinx/functions/array.rst | 9 ++++++ .../scalar/sql/ArraySqlFunctions.java | 10 +++++++ .../scalar/sql/TestArraySqlFunctions.java | 28 +++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index edfafb19f3876..329b2f6a914f4 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -151,6 +151,15 @@ Array Functions -1, IF(cardinality(x) = cardinality(y), 0, 1))); -- [[1, 2], [2, 3, 1], [4, 2, 1, 4]] +.. function:: array_sort_desc(x) -> array + + Returns the ``array`` sorted in the descending order. Elements of the ``array`` must be orderable. + Null elements will be placed at the end of the returned array. + + SELECT array_sort_desc(ARRAY [100, 1, 10, 50]); -- [100, 50, 10, 1] + SELECT array_sort_desc(ARRAY [null, 100, null, 1, 10, 50]); -- [100, 50, 10, 1, null, null] + SELECT array_sort_desc(ARRAY [ARRAY ["a", null], null, ARRAY ["a"]); -- [["a", null], ["a"], null] + .. function:: array_sum(array(T)) -> bigint/double Returns the sum of all non-null elements of the ``array``. If there is no non-null elements, returns ``0``. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java index b9057b6cd8e37..cf31f9abdd587 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/ArraySqlFunctions.java @@ -116,4 +116,14 @@ public static String arrayMinBy() "array_min(zip_with(transform(input, f), sequence(1, cardinality(input)), (x, y)->IF(x IS NULL, NULL, (x, y))))[2]" + "]"; } + + @SqlInvokedScalarFunction(value = "array_sort_desc", deterministic = true, calledOnNullInput = true) + @Description("Sorts the given array in descending order according to the natural ordering of its elements.") + @TypeParameter("T") + @SqlParameter(name = "input", type = "array(T)") + @SqlType("array") + public static String arraySortDesc() + { + return "RETURN reverse(array_sort(remove_nulls(input))) || filter(input, x -> x is null)"; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java index 42ed89c10df37..a8097c3b67fb4 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java @@ -17,6 +17,7 @@ import com.facebook.presto.common.type.RowType; import com.facebook.presto.operator.scalar.AbstractTestFunctions; import com.facebook.presto.spi.StandardErrorCode; +import com.facebook.presto.sql.analyzer.SemanticErrorCode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -26,10 +27,12 @@ import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.util.StructuralTestUtil.mapType; import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; public class TestArraySqlFunctions @@ -214,4 +217,29 @@ public void testArrayMinBy() assertFunction("ARRAY_MIN_BY(ARRAY [cast(null as double), cast(null as double)], i -> i)", DOUBLE, null); assertFunction("ARRAY_MIN_BY(cast(null as array(double)), i -> i)", DOUBLE, null); } + + @Test + public void testArraySortDesc() + { + assertFunction("ARRAY_SORT_DESC(ARRAY [100, 1, 10, 50])", new ArrayType(INTEGER), ImmutableList.of(100, 50, 10, 1)); + assertFunction("ARRAY_SORT_DESC(ARRAY [null, null, 100, 1, 10, 50])", new ArrayType(INTEGER), asList(100, 50, 10, 1, null, null)); + assertFunction("ARRAY_SORT_DESC(ARRAY [double'1.0', double'2.0'])", new ArrayType(DOUBLE), ImmutableList.of(2.0d, 1.0d)); + assertFunction("ARRAY_SORT_DESC(ARRAY [double'1.0', double'2.0'])", new ArrayType(DOUBLE), ImmutableList.of(2.0d, 1.0d)); + assertFunction("ARRAY_SORT_DESC(ARRAY [null, double'-3.0', double'2.0', null])", new ArrayType(DOUBLE), asList(2.0d, -3.0d, null, null)); + assertFunction("ARRAY_SORT_DESC(ARRAY ['a', 'bb', 'c'])", new ArrayType(createVarcharType(2)), ImmutableList.of("c", "bb", "a")); + assertFunction("ARRAY_SORT_DESC(ARRAY ['a', 'bb', 'c', null])", new ArrayType(createVarcharType(2)), asList("c", "bb", "a", null)); + assertFunction("ARRAY_SORT_DESC(ARRAY [null, null, null])", new ArrayType(UNKNOWN), asList(null, null, null)); + assertFunction("ARRAY_SORT_DESC(ARRAY [])", new ArrayType(UNKNOWN), emptyList()); + assertFunction("ARRAY_SORT_DESC(null)", new ArrayType(UNKNOWN), null); + assertFunction("ARRAY_SORT_DESC(" + + "ARRAY [ARRAY['a'], ARRAY['b', 'b'], ARRAY['c']])", + new ArrayType(new ArrayType(createVarcharType(1))), + ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("b", "b"), ImmutableList.of("a"))); + assertFunction("ARRAY_SORT_DESC(" + + "ARRAY [ARRAY['a'], ARRAY['b', 'b'], ARRAY['c'], null, null, ARRAY['a', NULL]])", + new ArrayType(new ArrayType(createVarcharType(1))), + asList(singletonList("c"), ImmutableList.of("b", "b"), asList("a", null), singletonList("a"), null, null)); + assertInvalidFunction("ARRAY_SORT_DESC(ARRAY [ROW('a', 1), ROW('a', null), null, ROW('a', 0)])", StandardErrorCode.INVALID_FUNCTION_ARGUMENT); + assertInvalidFunction("ARRAY_SORT_DESC(ARRAY [MAP(ARRAY['foo', 'bar'], ARRAY[1, 2]), MAP(ARRAY['foo', 'bar'], ARRAY[0, 3])])", SemanticErrorCode.FUNCTION_NOT_FOUND); + } }