diff --git a/presto-docs/src/main/sphinx/functions/map.rst b/presto-docs/src/main/sphinx/functions/map.rst index 6e9cd04f3ba4d..5079c3c3bc1f1 100644 --- a/presto-docs/src/main/sphinx/functions/map.rst +++ b/presto-docs/src/main/sphinx/functions/map.rst @@ -98,6 +98,10 @@ Map Functions SELECT map_top_n_keys(map(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), 2, (x, y) -> IF(x < y, -1, IF(x = y, 0, 1))) --- ['c', 'b'] +.. function:: map_remove_null_values(x(K,V)) -> map(K, V) + + Removes all the entries where the value is null from the map ``x``. + .. function:: map_normalize(x(varchar,double)) -> map(varchar,double) Returns the map with the same keys but all non-null values are scaled proportionally so that the sum of values becomes 1. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java index 2c19dc7d762aa..bae51afa0b141 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java @@ -67,4 +67,15 @@ public static String mapTopNValuesComparator() { return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), slice(reverse(array_sort(remove_nulls(map_values(input)), f)) || filter(map_values(input), x -> x is null), 1, n))"; } + + @SqlInvokedScalarFunction(value = "map_remove_null_values", deterministic = true, calledOnNullInput = true) + @Description("Constructs a map by removing all the keys with null values.") + @TypeParameter("K") + @TypeParameter("V") + @SqlParameter(name = "input", type = "map(K, V)") + @SqlType("map(K, V)") + public static String mapRemoveNulls() + { + return "RETURN map_filter(input, (k, v) -> v is not null)"; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapRemoveNullValuesFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapRemoveNullValuesFunction.java new file mode 100644 index 0000000000000..59cc98ca0059c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapRemoveNullValuesFunction.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar.sql; + +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.DecimalType.createDecimalType; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.util.StructuralTestUtil.mapType; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; + +public class TestMapRemoveNullValuesFunction + extends AbstractTestFunctions +{ + @Test + public void test() + { + assertFunction( + "MAP_REMOVE_NULL_VALUES(MAP(ARRAY[1, 2, 3], ARRAY[4, 5, 6]))", + mapType(INTEGER, INTEGER), + ImmutableMap.of(1, 4, 2, 5, 3, 6)); + assertFunction( + "MAP_REMOVE_NULL_VALUES(MAP(ARRAY[-1, -2, -3], ARRAY[null, 5.0, null]))", + mapType(INTEGER, createDecimalType(2, 1)), + ImmutableMap.of(-2, decimal("5.0"))); + assertFunction( + "MAP_REMOVE_NULL_VALUES(MAP(ARRAY['ab', 'bc', 'cd'], ARRAY[null, null, null]))", + mapType(createVarcharType(2), UNKNOWN), + emptyMap()); + assertFunction( + "MAP_REMOVE_NULL_VALUES(MAP(ARRAY[123.0, 99.5, 1000.99], ARRAY['x', 'y', 'z']))", + mapType(createDecimalType(6, 2), createVarcharType(1)), + ImmutableMap.of(decimal("123.00"), "x", decimal("99.50"), "y", decimal("1000.99"), "z")); + assertFunction( + "MAP_REMOVE_NULL_VALUES(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[], ARRAY[null]]))", + mapType(createVarcharType(1), new ArrayType(INTEGER)), + ImmutableMap.of("a", singletonList(1), "b", emptyList(), "c", singletonList(null))); + assertFunction( + "MAP_REMOVE_NULL_VALUES(MAP(ARRAY[], ARRAY[]))", + mapType(UNKNOWN, UNKNOWN), + emptyMap()); + assertFunction( + "MAP_REMOVE_NULL_VALUES(null)", + mapType(UNKNOWN, UNKNOWN), + null); + } +}