Skip to content

Commit

Permalink
Adding support of MAP_TOP_N_KEYS to Presto
Browse files Browse the repository at this point in the history
  • Loading branch information
jainavi17 authored and Sreeni Viswanadha committed Nov 8, 2022
1 parent 87c43b2 commit 05fcfbf
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 0 deletions.
17 changes: 17 additions & 0 deletions presto-docs/src/main/sphinx/functions/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ Map Functions

Returns all the keys in the map ``x``.

.. function:: map_top_n_keys(x(K,V), n) -> array(K)

Returns the top ``n`` keys of the map ``x``, sorted in descending order.
``n`` must be a non-negative integer.
For the bottom ``n`` keys, use the overload that accepts a lambda comparator to perform custom sorting. Example ::

SELECT map_top_n_keys(map(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), 2) --- ['c', 'b']

.. function:: map_top_n_keys(x(K,V), n, function(K,K,int)) -> array(K)

Returns the top ``n`` keys of the map ``x``, ordered by the given comparator ``function``. The comparator takes
two non-nullable arguments representing two keys of the map. It returns -1, 0, or 1
as the first key is less than, equal to, or greater than the second key.
If the comparator function returns any other value (including ``NULL``), the query fails with an error ::

SELECT map_top_n_keys(map(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), 2, (x, y) -> IF(x < y, -1, IF(x = y, 0, 1))) --- ['c', 'b']

.. function:: map_normalize(x(varchar,double)) -> map(varchar,double)

Returns the map with the same keys but all non-null values are scaled proportionally so that the sum of values becomes 1.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
import com.facebook.presto.operator.scalar.WordStemFunction;
import com.facebook.presto.operator.scalar.sql.ArraySqlFunctions;
import com.facebook.presto.operator.scalar.sql.MapNormalizeFunction;
import com.facebook.presto.operator.scalar.sql.MapSqlFunctions;
import com.facebook.presto.operator.scalar.sql.SimpleSamplingPercent;
import com.facebook.presto.operator.window.CumulativeDistributionFunction;
import com.facebook.presto.operator.window.DenseRankFunction;
Expand Down Expand Up @@ -897,6 +898,7 @@ private List<? extends SqlFunction> getBuildInFunctions(FeaturesConfig featuresC
.sqlInvokedScalar(MapNormalizeFunction.class)
.sqlInvokedScalars(ArraySqlFunctions.class)
.sqlInvokedScalars(ArrayIntersectFunction.class)
.sqlInvokedScalars(MapSqlFunctions.class)
.sqlInvokedScalars(SimpleSamplingPercent.class)
.scalar(DynamicFilterPlaceholderFunction.class)
.scalars(EnumCasts.class)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar.sql;

import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
import com.facebook.presto.spi.function.SqlParameter;
import com.facebook.presto.spi.function.SqlParameters;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

/**
 * SQL-invoked scalar functions that operate on maps.
 *
 * <p>Each method returns the SQL text ({@code RETURN ...}) that forms the body of the function
 * described by its annotations; the Java methods themselves take no arguments and perform no computation.
 */
public class MapSqlFunctions
{
    private MapSqlFunctions() {}

    /**
     * SQL body for {@code map_top_n_keys(map(K, V), bigint) -> array<K>}.
     *
     * <p>The body sorts the map's keys with {@code array_sort}, reverses the result so the
     * largest keys come first, and keeps the leading {@code n} entries via {@code slice(..., 1, n)}.
     * A negative {@code n} is rejected through {@code fail(...)}.
     *
     * <p>NOTE(review): this overload declares {@code calledOnNullInput = false} while the comparator
     * overload declares {@code true} — confirm the difference is intentional.
     *
     * @return the SQL text of the function body
     */
    @SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = false)
    @Description("Returns the top N keys of the given map in descending order according to the natural ordering of its keys.")
    @TypeParameter("K")
    @TypeParameter("V")
    @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint")})
    @SqlType("array<K>")
    public static String mapTopNKeys()
    {
        return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(map_keys(input))), 1, n))";
    }

    /**
     * SQL body for {@code map_top_n_keys(map(K, V), bigint, function(K, K, int)) -> array<K>}.
     *
     * <p>Identical to the two-argument overload except that the keys are sorted with the
     * caller-supplied comparator {@code f}. Because the sorted array is reversed before slicing,
     * the keys the comparator ranks highest are returned first.
     * A negative {@code n} is rejected through {@code fail(...)}.
     *
     * @return the SQL text of the function body
     */
    @SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = true)
    @Description("Returns the top N keys of the given map sorted using the provided lambda comparator.")
    @TypeParameter("K")
    @TypeParameter("V")
    @SqlParameters({@SqlParameter(name = "input", type = "map(K, V)"), @SqlParameter(name = "n", type = "bigint"), @SqlParameter(name = "f", type = "function(K, K, int)")})
    @SqlType("array<K>")
    public static String mapTopNKeysComparator()
    {
        return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(map_keys(input), f)), 1, n))";
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.sql.analyzer.SemanticErrorCode;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;

/**
 * Tests for the three-argument {@code MAP_TOP_N_KEYS(map, n, comparator)} overload.
 *
 * <p>Two comparators recur throughout: {@code IF(x < y, 1, IF(x = y, 0, -1))} ranks smaller keys
 * higher (so the "top" keys are the smallest), while {@code IF(x > y, 1, IF(x = y, 0, -1))}
 * follows the natural ordering (so the "top" keys are the largest).
 */
public class TestMapTopNKeysComparatorFunction
        extends AbstractTestFunctions
{
    /** Verifies that the comparator decides which keys count as "top" and in what order they are returned. */
    @Test
    public void testBasic()
    {
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 2, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of(4, 5));
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 2, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of(6, 5));
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY['x', 'y', 'z'], ARRAY[1, 2, 3]), 3, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(createVarcharType(1)), ImmutableList.of("z", "y", "x"));
    }

    /** Verifies that an {@code n} exceeding the number of keys returns every key, still in comparator order. */
    @Test
    public void testNLargerThanMapSize()
    {
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 8, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of(4, 5, 6));
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 9, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of(6, 5, 4));
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY['x', 'y', 'z'], ARRAY[1, 2, 3]), 10, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(createVarcharType(1)), ImmutableList.of("z", "y", "x"));
        // Length-then-lexicographic comparator. Equal keys must compare as 0 to honour the
        // -1/0/1 comparator contract; the keys here are distinct, so the expected order is unaffected.
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY['abc', 'ab', 'a', 'b'], ARRAY[1, 2, 3, 4]), 10, (x, y) -> CASE " +
                        "WHEN LENGTH(x) > LENGTH(y) THEN 1 " +
                        "WHEN LENGTH(x) < LENGTH(y) THEN -1 " +
                        "WHEN x > y THEN 1 " +
                        "WHEN x < y THEN -1 " +
                        "ELSE 0 END)",
                new ArrayType(createVarcharType(3)), ImmutableList.of("abc", "ab", "b", "a"));
    }

    /** Verifies that a negative {@code n} fails with the function's own error message. */
    @Test
    public void testNegativeN()
    {
        assertInvalidFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), -1, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0");
        assertInvalidFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), -2, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0");
        assertInvalidFunction("MAP_TOP_N_KEYS(MAP(ARRAY['x', 'y', 'z'], ARRAY[1, 2, 3]), -3, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", StandardErrorCode.GENERIC_USER_ERROR, "n must be greater than or equal to 0");
    }

    /** Verifies that {@code n = 0} yields an empty array of the key type. */
    @Test
    public void testZeroN()
    {
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 0, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of());
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 0, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(INTEGER), ImmutableList.of());
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY['x', 'y', 'z'], ARRAY[1, 2, 3]), 0, (x, y) -> IF(x > y, 1, IF(x = y, 0, -1)))", new ArrayType(createVarcharType(1)), ImmutableList.of());
    }

    /** Verifies that an empty map yields an empty array. */
    @Test
    public void testEmpty()
    {
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[], ARRAY[]), 1, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(UNKNOWN), ImmutableList.of());
    }

    /** Verifies that a {@code NULL} map yields {@code NULL}. */
    @Test
    public void testNull()
    {
        assertFunction("MAP_TOP_N_KEYS(NULL, 1, (x, y) -> IF(x < y, 1, IF(x = y, 0, -1)))", new ArrayType(UNKNOWN), null);
    }

    /** Verifies that row-typed keys can be ranked by a comparator that inspects a single field. */
    @Test
    public void testComplexKeys()
    {
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[ROW('x', 1), ROW('y', 2), ROW('z', 3)], ARRAY[1, 2, 3]), 3, (x, y) -> IF(x[1] < y[1], 1, IF(x[1] = y[1], 0, -1)))", new ArrayType(RowType.from(ImmutableList.of(RowType.field(createVarcharType(1)), RowType.field(INTEGER)))), ImmutableList.of(ImmutableList.of("x", 1), ImmutableList.of("y", 2), ImmutableList.of("z", 3)));
    }

    /** Verifies the failure modes for comparators that are out of range, missing, or ill-typed. */
    @Test
    public void testBadLambda()
    {
        // Comparator result outside {-1, 0, 1}.
        assertInvalidFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 1, (x, y) -> 10)", StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Lambda comparator must return either -1, 0, or 1");
        // A literal null in place of the lambda does not resolve to any overload.
        assertInvalidFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 2, null)", SemanticErrorCode.FUNCTION_NOT_FOUND);
        // Comparing an integer key against a varchar literal is a type error.
        assertInvalidFunction("MAP_TOP_N_KEYS(MAP(ARRAY[4, 5, 6], ARRAY[1, 2, 3]), 3, (x, y) -> IF(x = 'test', 1, -1))", SemanticErrorCode.TYPE_MISMATCH);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar.sql;

import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.operator.scalar.AbstractTestFunctions;
import com.facebook.presto.spi.StandardErrorCode;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

import static com.facebook.presto.common.type.DecimalType.createDecimalType;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;

/**
 * Tests for the two-argument {@code MAP_TOP_N_KEYS(map, n)} overload, which returns the
 * largest {@code n} keys in descending order.
 */
public class TestMapTopNKeysFunction
        extends AbstractTestFunctions
{
    // Message raised by the function when n is negative.
    private static final String NEGATIVE_N_MESSAGE = "n must be greater than or equal to 0";

    /** Asserts that the given projection is rejected because {@code n} is negative. */
    private void assertNegativeNRejected(String projection)
    {
        assertInvalidFunction(projection, StandardErrorCode.GENERIC_USER_ERROR, NEGATIVE_N_MESSAGE);
    }

    /** Integer, varchar and decimal keys come back largest-first, truncated to {@code n}. */
    @Test
    public void testBasic()
    {
        ArrayType integerKeys = new ArrayType(INTEGER);
        ArrayType varcharKeys = new ArrayType(createVarcharType(2));
        ArrayType decimalKeys = new ArrayType(createDecimalType(6, 2));

        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[1, 2, 3], ARRAY[4, 5, 6]), 2)", integerKeys, ImmutableList.of(3, 2));
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[-1, -2, -3], ARRAY[4, 5, 6]), 2)", integerKeys, ImmutableList.of(-1, -2));
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY['ab', 'bc', 'cd'], ARRAY['x', 'y', 'z']), 1)", varcharKeys, ImmutableList.of("cd"));
        assertFunction(
                "MAP_TOP_N_KEYS(MAP(ARRAY[123.0, 99.5, 1000.99], ARRAY['x', 'y', 'z']), 3)",
                decimalKeys,
                ImmutableList.of(decimal("1000.99"), decimal("123.00"), decimal("99.50")));
    }

    /** A negative {@code n} is an error regardless of key type. */
    @Test
    public void testNegativeN()
    {
        assertNegativeNRejected("MAP_TOP_N_KEYS(MAP(ARRAY[100, 200, 300], ARRAY[4, 5, 6]), -3)");
        assertNegativeNRejected("MAP_TOP_N_KEYS(MAP(ARRAY[1, 2, 3], ARRAY[4, 5, 6]), -1)");
        assertNegativeNRejected("MAP_TOP_N_KEYS(MAP(ARRAY['a', 'b', 'c'], ARRAY[4, 5, 6]), -2)");
    }

    /** {@code n = 0} produces an empty array of the key type. */
    @Test
    public void testZeroN()
    {
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[-1, -2, -3], ARRAY[4, 5, 6]), 0)", new ArrayType(INTEGER), ImmutableList.of());
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY['ab', 'bc', 'cd'], ARRAY['x', 'y', 'z']), 0)", new ArrayType(createVarcharType(2)), ImmutableList.of());
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[123.0, 99.5, 1000.99], ARRAY['x', 'y', 'z']), 0)", new ArrayType(createDecimalType(6, 2)), ImmutableList.of());
    }

    /** An empty map produces an empty array. */
    @Test
    public void testEmpty()
    {
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[], ARRAY[]), 5)", new ArrayType(UNKNOWN), ImmutableList.of());
    }

    /** A {@code NULL} map produces {@code NULL}. */
    @Test
    public void testNull()
    {
        assertFunction("MAP_TOP_N_KEYS(NULL, 1)", new ArrayType(UNKNOWN), null);
    }

    /** Row-typed keys are ranked by their natural ordering. */
    @Test
    public void testComplexKeys()
    {
        ArrayType rowKeys = new ArrayType(RowType.from(ImmutableList.of(RowType.field(createVarcharType(1)), RowType.field(INTEGER))));

        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[ROW('x', 1), ROW('y', 2)], ARRAY[1, 2]), 1)", rowKeys, ImmutableList.of(ImmutableList.of("y", 2)));
        assertFunction("MAP_TOP_N_KEYS(MAP(ARRAY[ROW('x', 1), ROW('x', -2)], ARRAY[1, 2]), 1)", rowKeys, ImmutableList.of(ImmutableList.of("x", 1)));
    }
}

0 comments on commit 05fcfbf

Please sign in to comment.