Skip to content

Commit

Permalink
Adding support of ARRAY_SORT_DESC for Presto
Browse files Browse the repository at this point in the history
  • Loading branch information
jainavi17 authored and Sreeni Viswanadha committed Nov 11, 2022
1 parent 43a5fa0 commit 19a9bb8
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
9 changes: 9 additions & 0 deletions presto-docs/src/main/sphinx/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ Array Functions
-1,
IF(cardinality(x) = cardinality(y), 0, 1))); -- [[1, 2], [2, 3, 1], [4, 2, 1, 4]]

.. function:: array_sort_desc(x) -> array

Returns the ``array`` sorted in the descending order. Elements of the ``array`` must be orderable.
Null elements will be placed at the end of the returned array.

SELECT array_sort_desc(ARRAY [100, 1, 10, 50]); -- [100, 50, 10, 1]
SELECT array_sort_desc(ARRAY [null, 100, null, 1, 10, 50]); -- [100, 50, 10, 1, null, null]
SELECT array_sort_desc(ARRAY [ARRAY ["a", null], null, ARRAY ["a"]); -- [["a", null], ["a"], null]

.. function:: array_sum(array(T)) -> bigint/double

Returns the sum of all non-null elements of the ``array``. If there is no non-null elements, returns ``0``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,14 @@ public static String arrayMinBy()
"array_min(zip_with(transform(input, f), sequence(1, cardinality(input)), (x, y)->IF(x IS NULL, NULL, (x, y))))[2]" +
"]";
}

@SqlInvokedScalarFunction(value = "array_sort_desc", deterministic = true, calledOnNullInput = true)
@Description("Sorts the given array in descending order according to the natural ordering of its elements.")
@TypeParameter("T")
@SqlParameter(name = "input", type = "array(T)")
@SqlType("array<T>")
public static String arraySortDesc()
{
return "RETURN reverse(array_sort(remove_nulls(input))) || filter(input, x -> x is null)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.operator.scalar.AbstractTestFunctions;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.sql.analyzer.SemanticErrorCode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;
Expand All @@ -26,10 +27,12 @@
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;
import static com.facebook.presto.util.StructuralTestUtil.mapType;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;

public class TestArraySqlFunctions
Expand Down Expand Up @@ -214,4 +217,29 @@ public void testArrayMinBy()
assertFunction("ARRAY_MIN_BY(ARRAY [cast(null as double), cast(null as double)], i -> i)", DOUBLE, null);
assertFunction("ARRAY_MIN_BY(cast(null as array(double)), i -> i)", DOUBLE, null);
}

@Test
public void testArraySortDesc()
{
assertFunction("ARRAY_SORT_DESC(ARRAY [100, 1, 10, 50])", new ArrayType(INTEGER), ImmutableList.of(100, 50, 10, 1));
assertFunction("ARRAY_SORT_DESC(ARRAY [null, null, 100, 1, 10, 50])", new ArrayType(INTEGER), asList(100, 50, 10, 1, null, null));
assertFunction("ARRAY_SORT_DESC(ARRAY [double'1.0', double'2.0'])", new ArrayType(DOUBLE), ImmutableList.of(2.0d, 1.0d));
assertFunction("ARRAY_SORT_DESC(ARRAY [double'1.0', double'2.0'])", new ArrayType(DOUBLE), ImmutableList.of(2.0d, 1.0d));
assertFunction("ARRAY_SORT_DESC(ARRAY [null, double'-3.0', double'2.0', null])", new ArrayType(DOUBLE), asList(2.0d, -3.0d, null, null));
assertFunction("ARRAY_SORT_DESC(ARRAY ['a', 'bb', 'c'])", new ArrayType(createVarcharType(2)), ImmutableList.of("c", "bb", "a"));
assertFunction("ARRAY_SORT_DESC(ARRAY ['a', 'bb', 'c', null])", new ArrayType(createVarcharType(2)), asList("c", "bb", "a", null));
assertFunction("ARRAY_SORT_DESC(ARRAY [null, null, null])", new ArrayType(UNKNOWN), asList(null, null, null));
assertFunction("ARRAY_SORT_DESC(ARRAY [])", new ArrayType(UNKNOWN), emptyList());
assertFunction("ARRAY_SORT_DESC(null)", new ArrayType(UNKNOWN), null);
assertFunction("ARRAY_SORT_DESC(" +
"ARRAY [ARRAY['a'], ARRAY['b', 'b'], ARRAY['c']])",
new ArrayType(new ArrayType(createVarcharType(1))),
ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("b", "b"), ImmutableList.of("a")));
assertFunction("ARRAY_SORT_DESC(" +
"ARRAY [ARRAY['a'], ARRAY['b', 'b'], ARRAY['c'], null, null, ARRAY['a', NULL]])",
new ArrayType(new ArrayType(createVarcharType(1))),
asList(singletonList("c"), ImmutableList.of("b", "b"), asList("a", null), singletonList("a"), null, null));
assertInvalidFunction("ARRAY_SORT_DESC(ARRAY [ROW('a', 1), ROW('a', null), null, ROW('a', 0)])", StandardErrorCode.INVALID_FUNCTION_ARGUMENT);
assertInvalidFunction("ARRAY_SORT_DESC(ARRAY [MAP(ARRAY['foo', 'bar'], ARRAY[1, 2]), MAP(ARRAY['foo', 'bar'], ARRAY[0, 3])])", SemanticErrorCode.FUNCTION_NOT_FOUND);
}
}

0 comments on commit 19a9bb8

Please sign in to comment.