Skip to content

Commit

Permalink
Generalize existing SQL UDF functions to all types
Browse files Browse the repository at this point in the history
  • Loading branch information
Sacha Viscaino authored and Sreeni Viswanadha committed Oct 31, 2022
1 parent 816187a commit e52a7f1
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 119 deletions.
14 changes: 2 additions & 12 deletions presto-docs/src/main/sphinx/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,11 @@ Array Functions

Returns a set of elements that occur more than once in ``array``.

``T`` must be coercible to ``bigint`` or ``varchar``.

.. function:: array_except(x, y) -> array

Returns an array of elements in ``x`` but not in ``y``, without duplicates.

.. function:: array_frequency(array(bigint)) -> map(bigint, int)

Returns a map: keys are the unique elements in the ``array``, values are how many times the key appears.
Ignores null elements. Empty array returns empty map.

.. function:: array_frequency(array(varchar)) -> map(varchar, int)
.. function:: array_frequency(array(E)) -> map(E, int)

Returns a map: keys are the unique elements in the ``array``, values are how many times the key appears.
Ignores null elements. Empty array returns empty map.
Expand All @@ -68,16 +61,13 @@ Array Functions

Returns a boolean: whether ``array`` has any elements that occur more than once.

``T`` must be coercible to ``bigint`` or ``varchar``.

.. function:: array_intersect(x, y) -> array

Returns an array of the elements in the intersection of ``x`` and ``y``, without duplicates.

.. function:: array_intersect(array(array(E))) -> array(bigint/double)
.. function:: array_intersect(array(array(E))) -> array(E)

Returns an array of the elements in the intersection of all arrays in the given array, without duplicates.
``E`` must be coercible to ``double``. Returns ``bigint`` if ``E`` is coercible to ``bigint``. Otherwise, returns ``double``.

.. function:: array_join(x, delimiter, null_replacement) -> varchar

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,10 @@ public static Block intersect(

/**
 * SQL-invoked overload of {@code array_intersect} declared for {@code array<array<bigint>>} input
 * and {@code array<bigint>} output.
 *
 * <p>The returned string is the SQL body of the function, not a computed result. The body folds
 * {@code input} with {@code reduce}: the state starts as {@code null}, an inner array met while
 * the state is {@code null} becomes the state unchanged, and otherwise the inner array is
 * combined with the state through the two-argument {@code array_intersect}.
 *
 * @return the SQL routine body for this overload
 */
@SqlInvokedScalarFunction(value = "array_intersect", deterministic = true, calledOnNullInput = false)
@Description("Intersects elements of all arrays in the given array")
@SqlParameter(name = "input", type = "array<array<bigint>>")
@SqlType("array<bigint>")
public static String arrayIntersectBigint()
{
// The fold is seeded with null instead of an empty array: an empty seed would make every
// subsequent intersection empty, so the first inner array has to become the state as-is.
// NOTE(review): null therefore means both "nothing seen yet" and "a previous step produced
// null". If a step yields null mid-fold (presumably what a null inner array does), the next
// inner array would be adopted as a fresh seed. Confirm whether an input such as
// [[1], null, [2]] is meant to return [2] rather than null.
return "RETURN reduce(input, null, (s, x) -> IF((s IS NULL), x, array_intersect(s, x)), (s) -> s)";
}

// NOTE(review): two declarations appear back to back over a single body here: one for
// array<array<double>> named arrayIntersectDouble, and one type-parameterised over T named
// arrayIntersectArray. Presumably these are the before/after headers of one change shown
// together; confirm which header is current before relying on this text.
@SqlInvokedScalarFunction(value = "array_intersect", deterministic = true, calledOnNullInput = false)
@Description("Intersects elements of all arrays in the given array")
@SqlParameter(name = "input", type = "array<array<double>>")
@SqlType("array<double>")
public static String arrayIntersectDouble()
@TypeParameter("T")
@SqlParameter(name = "input", type = "array<array<T>>")
@SqlType("array<T>")
public static String arrayIntersectArray()
{
// The returned string is the SQL body. It folds `input` with reduce, starting from a null
// state: an inner array met while the state is null becomes the state unchanged, otherwise it
// is intersected with the state via the two-argument array_intersect.
// NOTE(review): null means both "nothing seen yet" and "a previous step produced null", so a
// null produced mid-fold would presumably be replaced by the next inner array. Confirm the
// intended result for inputs that contain a null inner array followed by further arrays.
return "RETURN reduce(input, null, (s, x) -> IF((s IS NULL), x, array_intersect(s, x)), (s) -> s)";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
import com.facebook.presto.spi.function.SqlParameter;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

public class ArraySqlFunctions
{
Expand Down Expand Up @@ -55,8 +56,9 @@ public static String arrayAverage()

@SqlInvokedScalarFunction(value = "array_frequency", deterministic = true, calledOnNullInput = false)
@Description("Returns the frequency of all array elements as a map.")
@SqlParameter(name = "input", type = "array(bigint)")
@SqlType("map(bigint, int)")
@TypeParameter("T")
@SqlParameter(name = "input", type = "array(T)")
@SqlType("map(T, int)")
public static String arrayFrequencyBigint()
{
return "RETURN reduce(" +
Expand All @@ -66,56 +68,25 @@ public static String arrayFrequencyBigint()
"m -> m)";
}

/**
 * SQL-invoked overload of {@code array_frequency} declared for {@code array(varchar)} input and
 * {@code map(varchar, int)} output.
 *
 * <p>The returned string is the SQL body of the function. It folds {@code input} into a map,
 * starting from an empty {@code MAP()}: a non-null element contributes a single-entry map whose
 * value is the element's current count plus one, or {@code 1} when the element has no entry
 * yet; a null element leaves the accumulated map unchanged.
 *
 * @return the SQL routine body for this overload
 */
@SqlInvokedScalarFunction(value = "array_frequency", deterministic = true, calledOnNullInput = false)
@Description("Returns the frequency of all array elements as a map.")
@SqlParameter(name = "input", type = "array(varchar)")
@SqlType("map(varchar, int)")
public static String arrayFrequencyVarchar()
{
// COALESCE(ELEMENT_AT(m, x) + 1, 1): a missing key makes the addition null, so the count
// falls back to 1 on first sight of x.
// Updating a count relies on MAP_CONCAT taking the value from the later map when a key is
// present in both arguments -- presumably the documented behaviour; verify against the engine.
return "RETURN reduce(" +
"input," +
"MAP()," +
"(m, x) -> IF (x IS NOT NULL, MAP_CONCAT(m,MAP_FROM_ENTRIES(ARRAY[ROW(x, COALESCE(ELEMENT_AT(m,x) + 1, 1))])), m)," +
"m -> m)";
}

/**
 * SQL-invoked overload of {@code array_duplicates} (alias {@code array_dupes}) declared for
 * {@code array(varchar)} input and output.
 *
 * <p>The returned string is the SQL body of the function. It concatenates two arrays: a
 * one-element array holding {@code NULL} when {@code input} contains more than one null
 * (otherwise an empty array), followed by the keys of {@code array_frequency(input)} whose
 * count is greater than one.
 *
 * @return the SQL routine body for this overload
 */
@SqlInvokedScalarFunction(value = "array_duplicates", alias = {"array_dupes"}, deterministic = true, calledOnNullInput = false)
@Description("Returns set of elements that have duplicates")
@SqlParameter(name = "input", type = "array(varchar)")
@SqlType("array(varchar)")
public static String arrayDuplicatesVarchar()
{
// Repeated nulls are detected by a separate filter/cardinality check because the frequency
// map is built from non-null elements only.
// The CAST gives the ARRAY[NULL] / ARRAY[] literal a concrete element type so that it can be
// concatenated with the varchar keys.
return "RETURN CONCAT(" +
"CAST(IF (cardinality(filter(input, x -> x is NULL)) > 1, ARRAY[NULL], ARRAY[]) AS ARRAY(VARCHAR))," +
"map_keys(map_filter(array_frequency(input), (k, v) -> v > 1)))";
}

// NOTE(review): this span lists a bigint declaration (arrayDuplicatesBigint) and a
// type-parameterised one (arrayDuplicates) over a single body, and the body shows two
// alternative first arguments to CONCAT. Presumably these are the before/after lines of one
// change shown together; confirm which lines are current.
@SqlInvokedScalarFunction(value = "array_duplicates", alias = {"array_dupes"}, deterministic = true, calledOnNullInput = false)
@Description("Returns set of elements that have duplicates")
@SqlParameter(name = "input", type = "array(bigint)")
@SqlType("array(bigint)")
public static String arrayDuplicatesBigint()
@SqlParameter(name = "input", type = "array(T)")
@TypeParameter("T")
@SqlType("array(T)")
public static String arrayDuplicates()
{
// The SQL body concatenates (1) a one-element array holding null when `input` contains more
// than one null, else an empty array, with (2) the keys of array_frequency(input) whose count
// exceeds one.
// First-argument variants: the CAST form pins the element type to BIGINT; the find_first form
// takes the null out of `input` itself, presumably so the array's element type follows the
// input type without naming a concrete type -- confirm.
return "RETURN CONCAT(" +
"CAST(IF (cardinality(filter(input, x -> x is NULL)) > 1, ARRAY[NULL], ARRAY[]) AS ARRAY(BIGINT))," +
"IF (cardinality(filter(input, x -> x is NULL)) > 1, array[find_first(input, x -> x IS NULL)], array[])," +
"map_keys(map_filter(array_frequency(input), (k, v) -> v > 1)))";
}

// NOTE(review): both an array(varchar) parameter and a type-parameterised array(T) parameter
// are listed for this one method, and the method name still says "Varchar". Presumably the
// varchar line is the superseded one; confirm, and confirm whether the name should still
// mention a concrete type.
@SqlInvokedScalarFunction(value = "array_has_duplicates", alias = {"array_has_dupes"}, deterministic = true, calledOnNullInput = false)
@Description("Returns whether array has any duplicate element")
@SqlParameter(name = "input", type = "array(varchar)")
@TypeParameter("T")
@SqlParameter(name = "input", type = "array(T)")
@SqlType("boolean")
public static String arrayHasDuplicatesVarchar()
{
// The SQL body delegates to array_duplicates: the result is true exactly when that function
// returns a non-empty array for `input`.
return "RETURN cardinality(array_duplicates(input)) > 0";
}

/**
 * SQL-invoked overload of {@code array_has_duplicates} (alias {@code array_has_dupes}) declared
 * for {@code array(bigint)} input, returning {@code boolean}.
 *
 * <p>The returned string is the SQL body of the function; it is true exactly when
 * {@code array_duplicates(input)} is non-empty.
 *
 * @return the SQL routine body for this overload
 */
@SqlInvokedScalarFunction(value = "array_has_duplicates", alias = {"array_has_dupes"}, deterministic = true, calledOnNullInput = false)
@Description("Returns whether array has any duplicate element")
@SqlParameter(name = "input", type = "array(bigint)")
@SqlType("boolean")
public static String arrayHasDuplicatesBigint()
{
// All duplicate detection lives in array_duplicates; this body only checks its cardinality.
return "RETURN cardinality(array_duplicates(input)) > 0";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -69,9 +70,15 @@ public void testDuplicates()
/**
 * Exercises the single-argument (array-of-arrays) form of {@code array_intersect} on integer,
 * double, nested-array and row element types.
 */
@Test
public void testSQLFunctions()
{
// NOTE(review): the next three expressions are asserted twice, once expecting ARRAY(BIGINT)
// and once expecting ARRAY(INTEGER). Both expectations cannot hold at once, so presumably
// these are before/after lines of one change shown together -- confirm which set is current.
assertFunction("array_intersect(ARRAY[ARRAY[1, 3, 5], ARRAY[2, 3, 5], ARRAY[3, 3, 3, 6]])", new ArrayType(BIGINT), ImmutableList.of(3L));
assertFunction("array_intersect(ARRAY[ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(BIGINT), ImmutableList.of());
assertFunction("array_intersect(ARRAY[ARRAY[1, 2, 3], null])", new ArrayType(BIGINT), null);
assertFunction("array_intersect(ARRAY[ARRAY[1.1, 2.2, 3.3], ARRAY[1.1, 3.4], ARRAY[1.0, 1.1, 1.2]])", new ArrayType(DOUBLE), ImmutableList.of(1.1));
// Same inputs with INTEGER element type expected; a null inner array is expected to make the
// whole result null, and an empty inner array to make it empty.
assertFunction("array_intersect(ARRAY[ARRAY[1, 3, 5], ARRAY[2, 3, 5], ARRAY[3, 3, 3, 6]])", new ArrayType(INTEGER), ImmutableList.of(3));
assertFunction("array_intersect(ARRAY[ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(INTEGER), ImmutableList.of());
assertFunction("array_intersect(ARRAY[ARRAY[1, 2, 3], null])", new ArrayType(INTEGER), null);
// Explicitly typed DOUBLE literals, so the expected element type is DOUBLE.
assertFunction("array_intersect(ARRAY[ARRAY[DOUBLE'1.1', DOUBLE'2.2', DOUBLE'3.3'], ARRAY[DOUBLE'1.1', DOUBLE'3.4'], ARRAY[DOUBLE'1.0', DOUBLE'1.1', DOUBLE'1.2']])", new ArrayType(DOUBLE), ImmutableList.of(1.1));

// Element type is itself an array: only ARRAY[2] is present in both inner arrays.
assertFunction("array_intersect(ARRAY[ARRAY[ARRAY[1], ARRAY[2]], ARRAY[ARRAY[2], ARRAY[3]]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(2)));

// Element type is a row of two doubles; the row type's string form is spliced into the SQL
// text as the CAST target, and only the row (1.0, 2.0) is common to both inner arrays.
RowType rowType = RowType.from(ImmutableList.of(RowType.field("x", DOUBLE), RowType.field("y", DOUBLE)));
String t = rowType.toString();
assertFunction("array_intersect(ARRAY[ARRAY[CAST((1.0, 2.0) AS " + t + "), CAST((2.0, 3.0) AS " + t + ")], ARRAY[CAST((0.0, 1.0) AS " + t + "), CAST((1.0, 2.0) AS " + t + ")]])", new ArrayType(rowType), ImmutableList.of(ImmutableList.of(1.0, 2.0)));
}
}
Loading

0 comments on commit e52a7f1

Please sign in to comment.