diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 0e13935cf7a98..3ee55ba73504e 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -315,6 +315,7 @@ public final class SystemSessionProperties public static final String USE_PARTIAL_AGGREGATION_HISTORY = "use_partial_aggregation_history"; public static final String TRACK_PARTIAL_AGGREGATION_HISTORY = "track_partial_aggregation_history"; public static final String REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN = "remove_redundant_cast_to_varchar_in_join"; + public static final String REMOVE_MAP_CAST = "remove_map_cast"; public static final String HANDLE_COMPLEX_EQUI_JOINS = "handle_complex_equi_joins"; public static final String SKIP_HASH_GENERATION_FOR_JOIN_WITH_TABLE_SCAN_INPUT = "skip_hash_generation_for_join_with_table_scan_input"; public static final String GENERATE_DOMAIN_FILTERS = "generate_domain_filters"; @@ -1907,6 +1908,11 @@ public SystemSessionProperties( "If both left and right side of join clause are varchar cast from int/bigint, remove the cast here", featuresConfig.isRemoveRedundantCastToVarcharInJoin(), false), + booleanProperty( + REMOVE_MAP_CAST, + "Remove map cast when possible", + false, + false), booleanProperty( HANDLE_COMPLEX_EQUI_JOINS, "Handle complex equi-join conditions to open up join space for join reordering", @@ -3182,6 +3188,11 @@ public static boolean isRemoveRedundantCastToVarcharInJoinEnabled(Session sessio return session.getSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, Boolean.class); } + public static boolean isRemoveMapCastEnabled(Session session) + { + return session.getSystemProperty(REMOVE_MAP_CAST, Boolean.class); + } + public static boolean shouldHandleComplexEquiJoins(Session session) { return session.getSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, Boolean.class); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 37bc1cc88b8c0..269f89793e85f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -99,6 +99,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete; import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample; import com.facebook.presto.sql.planner.iterative.rule.RemoveIdentityProjectionsBelowProjection; +import com.facebook.presto.sql.planner.iterative.rule.RemoveMapCastRule; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantAggregateDistinct; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantCastToVarcharInJoinClause; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantDistinct; @@ -496,7 +497,8 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new RemoveRedundantCastToVarcharInJoinClause(metadata.getFunctionAndTypeManager())))); + ImmutableSet.>builder().add(new RemoveRedundantCastToVarcharInJoinClause(metadata.getFunctionAndTypeManager())) + .addAll(new RemoveMapCastRule(metadata.getFunctionAndTypeManager()).rules()).build())); builder.add(new IterativeOptimizer( metadata, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java new file mode 100644 index 0000000000000..428640b5fc556 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.common.type.MapType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.expressions.RowExpressionRewriter; +import com.facebook.presto.expressions.RowExpressionTreeRewriter; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import static com.facebook.presto.SystemSessionProperties.isRemoveMapCastEnabled; +import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.Expressions.castToInteger; +import static com.facebook.presto.sql.relational.Expressions.tryCast; +import static java.util.Objects.requireNonNull; + +/** + * Remove cast on map if possible. Currently it only supports subscript and element_at function, and only works when map key is of type integer and index is bigint. For example: + * Input: cast(feature as map)[key], where feature is of type map and key is of type bigint + * Output: feature[cast(key as integer)] + * + * Input: element_at(cast(feature as map), key), where feature is of type map and key is of type bigint + * Output: element_at(feature, try_cast(key as integer)) + * + * Notice that here when it's accessing the map using subscript function, we use CAST function in index, and when it's element_at function, we use TRY_CAST function, so that + * when the key is out of integer range, for feature[key] it will fail both with and without optimization, fail with map key not exists before optimization and with cast failure after optimization + * when the key is out of integer range, for element_at(feature, key) it will return NULL both before and after optimization + */ +public class RemoveMapCastRule + extends RowExpressionRewriteRuleSet +{ + public RemoveMapCastRule(FunctionAndTypeManager functionAndTypeManager) + { + super(new RemoveMapCastRule.Rewriter(functionAndTypeManager)); + } + + @Override + public boolean isRewriterEnabled(Session session) + { + return isRemoveMapCastEnabled(session); + } + + @Override + public Set> rules() + { + return ImmutableSet.of(filterRowExpressionRewriteRule(), projectRowExpressionRewriteRule()); + } + + private static class Rewriter + implements PlanRowExpressionRewriter + { + private final RemoveMapCastRewriter removeMapCastRewriter; + + public Rewriter(FunctionAndTypeManager functionAndTypeManager) + { + requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.removeMapCastRewriter = new RemoveMapCastRewriter(functionAndTypeManager); + } + + @Override + public RowExpression rewrite(RowExpression expression, Rule.Context context) + { + return RowExpressionTreeRewriter.rewriteWith(removeMapCastRewriter, expression); + } + } + + private static class RemoveMapCastRewriter + extends RowExpressionRewriter + { + private final FunctionAndTypeManager functionAndTypeManager; + private final FunctionResolution functionResolution; + + private RemoveMapCastRewriter(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = functionAndTypeManager; + this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); + } + + @Override + public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter treeRewriter) + { + if ((functionResolution.isSubscriptFunction(node.getFunctionHandle()) || functionResolution.isElementAtFunction(node.getFunctionHandle())) && node.getArguments().get(0) instanceof CallExpression + && functionResolution.isCastFunction(((CallExpression) node.getArguments().get(0)).getFunctionHandle()) + && ((CallExpression) node.getArguments().get(0)).getArguments().get(0).getType() instanceof MapType) { + CallExpression castExpression = (CallExpression) node.getArguments().get(0); + RowExpression castInput = castExpression.getArguments().get(0); + Type fromKeyType = ((MapType) castInput.getType()).getKeyType(); + Type fromValueType = ((MapType) castInput.getType()).getValueType(); + Type toKeyType = ((MapType) castExpression.getType()).getKeyType(); + Type toValueType = ((MapType) castExpression.getType()).getValueType(); + + if (canRemoveMapCast(fromKeyType, fromValueType, toKeyType, toValueType, node.getArguments().get(1).getType())) { + if (functionResolution.isSubscriptFunction(node.getFunctionHandle())) { + RowExpression newIndex = castToInteger(functionAndTypeManager, node.getArguments().get(1)); + return call(SUBSCRIPT.name(), functionResolution.subscriptFunction(castInput.getType(), newIndex.getType()), node.getType(), castInput, newIndex); + } + else { + RowExpression newIndex = tryCast(functionAndTypeManager, node.getArguments().get(1), INTEGER); + return call(functionAndTypeManager, "element_at", node.getType(), castInput, newIndex); + } + } + } + return null; + } + + private static boolean canRemoveMapCast(Type fromKeyType, Type fromValueType, Type toKeyType, Type toValueType, Type indexType) + { + return fromValueType.equals(toValueType) && fromKeyType.equals(INTEGER) && toKeyType.equals(BIGINT) && indexType.equals(BIGINT); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/Expressions.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/Expressions.java index 964b0c93bb49d..8596b9c5d35dc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/Expressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/Expressions.java @@ -42,6 +42,8 @@ import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.operator.scalar.TryCastFunction.TRY_CAST_NAME; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; @@ -171,6 +173,19 @@ public static RowExpression castToBigInt(FunctionAndTypeManager functionAndTypeM return call("CAST", functionAndTypeManager.lookupCast(CastType.CAST, rowExpression.getType(), BIGINT), BIGINT, rowExpression); } + public static RowExpression castToInteger(FunctionAndTypeManager functionAndTypeManager, RowExpression rowExpression) + { + if (rowExpression.getType().equals(INTEGER)) { + return rowExpression; + } + return call("CAST", functionAndTypeManager.lookupCast(CastType.CAST, rowExpression.getType(), INTEGER), INTEGER, rowExpression); + } + + public static RowExpression tryCast(FunctionAndTypeManager functionAndTypeManager, RowExpression rowExpression, Type castToType) + { + return call(TRY_CAST_NAME, functionAndTypeManager.lookupCast(CastType.TRY_CAST, rowExpression.getType(), castToType), castToType, rowExpression); + } + public static RowExpression searchedCaseExpression(List whenClauses, Optional defaultValue) { // We rewrite this as - CASE true WHEN p1 THEN v1 WHEN p2 THEN v2 .. ELSE v END diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java index e6a5de633bf30..7d7efada1cc8a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java @@ -343,4 +343,9 @@ public boolean isArrayContainsFunction(FunctionHandle functionHandle) { return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "contains")); } + + public boolean isElementAtFunction(FunctionHandle functionHandle) + { + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "element_at")); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java index 645dedd5b7e54..a7ca4830017df 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java @@ -183,7 +183,7 @@ protected Boolean visitCast(Cast expected, RowExpression actual) } return getValueFromLiteral(literal).equals(String.valueOf(LiteralInterpreter.evaluate(TEST_SESSION.toConnectorSession(), (ConstantExpression) actual))); } - if (!(actual instanceof CallExpression) || !functionResolution.isCastFunction(((CallExpression) actual).getFunctionHandle())) { + if (!(actual instanceof CallExpression) || (!functionResolution.isCastFunction(((CallExpression) actual).getFunctionHandle()) && !functionResolution.isTryCastFunction(((CallExpression) actual).getFunctionHandle()))) { return false; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java new file mode 100644 index 0000000000000..19638181d35e4 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.sql.planner.PlannerUtils.createMapType; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; + +public class TestRemoveMapCastRule + extends BaseRuleTest +{ + @Test + public void testSubscriptCast() + { + tester().assertThat( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + .setSystemProperty(REMOVE_MAP_CAST, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", DOUBLE); + VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); + VariableReferenceExpression key = p.variable("key", BIGINT); + return p.project( + assignment(a, p.rowExpression("cast(feature as map)[key]")), + p.values(feature, key)); + }) + .matches( + project( + ImmutableMap.of("a", expression("feature[cast(key as integer)]")), + values("feature", "key"))); + } + + @Test + public void testElementAtCast() + { + tester().assertThat( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + .setSystemProperty(REMOVE_MAP_CAST, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", DOUBLE); + VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); + VariableReferenceExpression key = p.variable("key", BIGINT); + return p.project( + assignment(a, p.rowExpression("element_at(cast(feature as map), key)")), + p.values(feature, key)); + }) + .matches( + project( + ImmutableMap.of("a", expression("element_at(feature, try_cast(key as integer))")), + values("feature", "key"))); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index c6daaf62506db..e8b134e6ae4e6 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -72,6 +72,7 @@ import static com.facebook.presto.SystemSessionProperties.QUICK_DISTINCT_LIMIT_ENABLED; import static com.facebook.presto.SystemSessionProperties.RANDOMIZE_OUTER_JOIN_NULL_KEY; import static com.facebook.presto.SystemSessionProperties.RANDOMIZE_OUTER_JOIN_NULL_KEY_STRATEGY; +import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST; import static com.facebook.presto.SystemSessionProperties.REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN; import static com.facebook.presto.SystemSessionProperties.REWRITE_CASE_TO_MAP_ENABLED; import static com.facebook.presto.SystemSessionProperties.REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION; @@ -7485,4 +7486,22 @@ public void testRepeat() assertQuery("select repeat(k1, k2), repeat(k1, 5), repeat(3, k2) from (values (3, 2), (5, 4), (2, 4))t(k1, k2)", "values (array[3, 3], array[3,3,3,3,3], array[3, 3]), (array[5, 5, 5, 5], array[5, 5, 5, 5, 5], array[3, 3, 3, 3]), (array[2, 2, 2, 2], array[2, 2, 2, 2, 2], array[3, 3, 3, 3])"); } + + @Test + public void testRemoveMapCast() + { + Session enableOptimization = Session.builder(getSession()) + .setSystemProperty(REMOVE_MAP_CAST, "true") + .build(); + assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", + "values 0.5, 0.1"); + assertQuery(enableOptimization, "select element_at(feature, key) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 4)) t(feature, key)", + "values 0.5, 0.1"); + assertQuery(enableOptimization, "select element_at(feature, key) from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", + "values 0.5, null"); + assertQueryFails(enableOptimization, "select feature[key] from (values (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), cast(2 as bigint)), (map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1]), 400000000000)) t(feature, key)", + ".*Out of range for integer.*"); + assertQuery(enableOptimization, "select feature[key] from (values (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), cast('2' as varchar)), (map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1]), '4')) t(feature, key)", + "values 0.5, 0.1"); + } }