Skip to content

Commit

Permalink
Remove map cast for element access
Browse files Browse the repository at this point in the history
Rewrite expression from cast(feature as map<bigint, float>)[key] -> feature[cast(key as integer)],
where feature is of type map<integer, float> and key is of type bigint, and
element_at(cast(feature as map<bigint, float>), key) -> element_at(feature, try_cast(key as integer)),
where feature is of type map<integer, float> and key is of type bigint, so as to get rid of map cast function.
  • Loading branch information
feilong-liu committed Mar 6, 2024
1 parent d80e49a commit 91aa1b1
Show file tree
Hide file tree
Showing 8 changed files with 262 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ public final class SystemSessionProperties
public static final String USE_PARTIAL_AGGREGATION_HISTORY = "use_partial_aggregation_history";
public static final String TRACK_PARTIAL_AGGREGATION_HISTORY = "track_partial_aggregation_history";
public static final String REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN = "remove_redundant_cast_to_varchar_in_join";
public static final String REMOVE_MAP_CAST = "remove_map_cast";
public static final String HANDLE_COMPLEX_EQUI_JOINS = "handle_complex_equi_joins";
public static final String SKIP_HASH_GENERATION_FOR_JOIN_WITH_TABLE_SCAN_INPUT = "skip_hash_generation_for_join_with_table_scan_input";
public static final String GENERATE_DOMAIN_FILTERS = "generate_domain_filters";
Expand Down Expand Up @@ -1907,6 +1908,11 @@ public SystemSessionProperties(
"If both left and right side of join clause are varchar cast from int/bigint, remove the cast here",
featuresConfig.isRemoveRedundantCastToVarcharInJoin(),
false),
booleanProperty(
REMOVE_MAP_CAST,
"Remove map cast when possible",
false,
false),
booleanProperty(
HANDLE_COMPLEX_EQUI_JOINS,
"Handle complex equi-join conditions to open up join space for join reordering",
Expand Down Expand Up @@ -3182,6 +3188,11 @@ public static boolean isRemoveRedundantCastToVarcharInJoinEnabled(Session sessio
return session.getSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, Boolean.class);
}

/**
 * Reads the boolean {@code remove_map_cast} system property from the given session.
 *
 * @param session session whose system properties are consulted
 * @return the value of the {@link #REMOVE_MAP_CAST} property
 */
public static boolean isRemoveMapCastEnabled(Session session)
{
    boolean removeMapCast = session.getSystemProperty(REMOVE_MAP_CAST, Boolean.class);
    return removeMapCast;
}

public static boolean shouldHandleComplexEquiJoins(Session session)
{
return session.getSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, Boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete;
import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample;
import com.facebook.presto.sql.planner.iterative.rule.RemoveIdentityProjectionsBelowProjection;
import com.facebook.presto.sql.planner.iterative.rule.RemoveMapCastRule;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantAggregateDistinct;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantCastToVarcharInJoinClause;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantDistinct;
Expand Down Expand Up @@ -496,7 +497,8 @@ public PlanOptimizers(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new RemoveRedundantCastToVarcharInJoinClause(metadata.getFunctionAndTypeManager()))));
ImmutableSet.<Rule<?>>builder().add(new RemoveRedundantCastToVarcharInJoinClause(metadata.getFunctionAndTypeManager()))
.addAll(new RemoveMapCastRule(metadata.getFunctionAndTypeManager()).rules()).build()));

builder.add(new IterativeOptimizer(
metadata,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableSet;

import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.isRemoveMapCastEnabled;
import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.castToInteger;
import static com.facebook.presto.sql.relational.Expressions.tryCast;
import static java.util.Objects.requireNonNull;

/**
 * Removes a cast on a map when the cast result is only used for element access.
 * Only the subscript operator and the {@code element_at} function are supported, and only when
 * the map key changes from integer to bigint, the value type is unchanged, and the lookup key is bigint.
 * <p>
 * Subscript:
 * <pre>
 * Input:  cast(feature as map&lt;bigint, float&gt;)[key]   -- feature: map&lt;integer, float&gt;, key: bigint
 * Output: feature[cast(key as integer)]
 * </pre>
 * {@code element_at}:
 * <pre>
 * Input:  element_at(cast(feature as map&lt;bigint, float&gt;), key)   -- feature: map&lt;integer, float&gt;, key: bigint
 * Output: element_at(feature, try_cast(key as integer))
 * </pre>
 * The key is converted with CAST for subscript and with TRY_CAST for {@code element_at} so that a key
 * outside the integer range behaves consistently with the original expression:
 * <ul>
 * <li>{@code feature[key]} fails both before and after the rewrite (missing map key before, cast failure after);</li>
 * <li>{@code element_at(feature, key)} returns NULL both before and after the rewrite.</li>
 * </ul>
 */
public class RemoveMapCastRule
        extends RowExpressionRewriteRuleSet
{
    public RemoveMapCastRule(FunctionAndTypeManager functionAndTypeManager)
    {
        super(new RemoveMapCastRule.Rewriter(functionAndTypeManager));
    }

    /**
     * The rewrite is gated by the {@code remove_map_cast} session property.
     */
    @Override
    public boolean isRewriterEnabled(Session session)
    {
        return isRemoveMapCastEnabled(session);
    }

    /**
     * Only filter and project expressions are rewritten by this rule set.
     */
    @Override
    public Set<Rule<?>> rules()
    {
        return ImmutableSet.of(filterRowExpressionRewriteRule(), projectRowExpressionRewriteRule());
    }

    /**
     * Adapter that runs {@link RemoveMapCastRewriter} over a whole expression tree.
     */
    private static class Rewriter
            implements PlanRowExpressionRewriter
    {
        private final RemoveMapCastRewriter removeMapCastRewriter;

        public Rewriter(FunctionAndTypeManager functionAndTypeManager)
        {
            requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.removeMapCastRewriter = new RemoveMapCastRewriter(functionAndTypeManager);
        }

        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context)
        {
            return RowExpressionTreeRewriter.rewriteWith(removeMapCastRewriter, expression);
        }
    }

    /**
     * Rewrites a single subscript / {@code element_at} call whose first argument is a removable map cast.
     */
    private static class RemoveMapCastRewriter
            extends RowExpressionRewriter<Void>
    {
        private final FunctionAndTypeManager functionAndTypeManager;
        private final FunctionResolution functionResolution;

        private RemoveMapCastRewriter(FunctionAndTypeManager functionAndTypeManager)
        {
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
        }

        /**
         * @return the rewritten call when {@code node} matches the supported pattern, otherwise {@code null}
         * (NOTE(review): null is presumably the "no custom rewrite" signal of RowExpressionRewriter - confirm)
         */
        @Override
        public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
        {
            boolean isSubscript = functionResolution.isSubscriptFunction(node.getFunctionHandle());
            if (!isSubscript && !functionResolution.isElementAtFunction(node.getFunctionHandle())) {
                return null;
            }

            RowExpression container = node.getArguments().get(0);
            if (!(container instanceof CallExpression)) {
                return null;
            }
            CallExpression castExpression = (CallExpression) container;
            if (!functionResolution.isCastFunction(castExpression.getFunctionHandle())) {
                return null;
            }

            RowExpression castInput = castExpression.getArguments().get(0);
            // Both sides of the cast must be maps. Checking only the input would leave the
            // MapType cast of the target type below unchecked.
            if (!(castInput.getType() instanceof MapType) || !(castExpression.getType() instanceof MapType)) {
                return null;
            }
            MapType fromType = (MapType) castInput.getType();
            MapType toType = (MapType) castExpression.getType();

            RowExpression index = node.getArguments().get(1);
            if (!canRemoveMapCast(fromType.getKeyType(), fromType.getValueType(), toType.getKeyType(), toType.getValueType(), index.getType())) {
                return null;
            }

            if (isSubscript) {
                // CAST keeps the "fails on out-of-range key" behavior of the original subscript
                RowExpression newIndex = castToInteger(functionAndTypeManager, index);
                return call(SUBSCRIPT.name(), functionResolution.subscriptFunction(castInput.getType(), newIndex.getType()), node.getType(), castInput, newIndex);
            }

            // TRY_CAST keeps the "returns NULL on out-of-range key" behavior of the original element_at
            RowExpression newIndex = tryCast(functionAndTypeManager, index, INTEGER);
            return call(functionAndTypeManager, "element_at", node.getType(), castInput, newIndex);
        }

        /**
         * The cast is removable only for map&lt;integer, V&gt; -> map&lt;bigint, V&gt; accessed with a bigint key.
         */
        private static boolean canRemoveMapCast(Type fromKeyType, Type fromValueType, Type toKeyType, Type toValueType, Type indexType)
        {
            return fromValueType.equals(toValueType) && fromKeyType.equals(INTEGER) && toKeyType.equals(BIGINT) && indexType.equals(BIGINT);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.operator.scalar.TryCastFunction.TRY_CAST_NAME;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
Expand Down Expand Up @@ -171,6 +173,19 @@ public static RowExpression castToBigInt(FunctionAndTypeManager functionAndTypeM
return call("CAST", functionAndTypeManager.lookupCast(CastType.CAST, rowExpression.getType(), BIGINT), BIGINT, rowExpression);
}

/**
 * Converts an expression to INTEGER.
 *
 * @param functionAndTypeManager used to look up the CAST from the expression's type to INTEGER
 * @param rowExpression expression to convert
 * @return {@code rowExpression} itself when it is already of type INTEGER, otherwise a CAST call wrapping it
 */
public static RowExpression castToInteger(FunctionAndTypeManager functionAndTypeManager, RowExpression rowExpression)
{
    Type inputType = rowExpression.getType();
    return inputType.equals(INTEGER)
            ? rowExpression
            : call("CAST", functionAndTypeManager.lookupCast(CastType.CAST, inputType, INTEGER), INTEGER, rowExpression);
}

/**
 * Wraps an expression in a TRY_CAST call to the requested type.
 *
 * @param functionAndTypeManager used to look up the TRY_CAST from the expression's type to {@code castToType}
 * @param rowExpression expression to convert
 * @param castToType target type of the conversion
 * @return a call named {@code TRY_CAST_NAME} with return type {@code castToType} and {@code rowExpression} as its only argument
 */
public static RowExpression tryCast(FunctionAndTypeManager functionAndTypeManager, RowExpression rowExpression, Type castToType)
{
    Type sourceType = rowExpression.getType();
    return call(
            TRY_CAST_NAME,
            functionAndTypeManager.lookupCast(CastType.TRY_CAST, sourceType, castToType),
            castToType,
            rowExpression);
}

public static RowExpression searchedCaseExpression(List<RowExpression> whenClauses, Optional<RowExpression> defaultValue)
{
// We rewrite this as - CASE true WHEN p1 THEN v1 WHEN p2 THEN v2 .. ELSE v END
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,4 +343,9 @@ public boolean isArrayContainsFunction(FunctionHandle functionHandle)
{
return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "contains"));
}

/**
 * Tells whether the given handle resolves to the {@code element_at} function in the default namespace.
 *
 * @param functionHandle handle whose metadata name is compared
 * @return true if the function's qualified name equals {@code element_at} under {@code DEFAULT_NAMESPACE}
 */
public boolean isElementAtFunction(FunctionHandle functionHandle)
{
    QualifiedObjectName elementAtName = QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "element_at");
    return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(elementAtName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ protected Boolean visitCast(Cast expected, RowExpression actual)
}
return getValueFromLiteral(literal).equals(String.valueOf(LiteralInterpreter.evaluate(TEST_SESSION.toConnectorSession(), (ConstantExpression) actual)));
}
if (!(actual instanceof CallExpression) || !functionResolution.isCastFunction(((CallExpression) actual).getFunctionHandle())) {
if (!(actual instanceof CallExpression) || (!functionResolution.isCastFunction(((CallExpression) actual).getFunctionHandle()) && !functionResolution.isTryCastFunction(((CallExpression) actual).getFunctionHandle()))) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;

import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.sql.planner.PlannerUtils.createMapType;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;

/**
 * Plan-level tests for {@link RemoveMapCastRule}: a projection over a map&lt;integer, double&gt; column and a
 * bigint key must lose the map cast once the {@code remove_map_cast} session property is enabled.
 */
public class TestRemoveMapCastRule
        extends BaseRuleTest
{
    @Test
    public void testSubscriptCast()
    {
        assertProjectionRewritten(
                "cast(feature as map<bigint, double>)[key]",
                "feature[cast(key as integer)]");
    }

    @Test
    public void testElementAtCast()
    {
        assertProjectionRewritten(
                "element_at(cast(feature as map<bigint, double>), key)",
                "element_at(feature, try_cast(key as integer))");
    }

    /**
     * Projects {@code original} as column "a" over values(feature, key), applies SimplifyRowExpressions
     * together with RemoveMapCastRule, and checks that the projection becomes {@code rewritten}.
     */
    private void assertProjectionRewritten(String original, String rewritten)
    {
        ImmutableSet<Rule<?>> rulesUnderTest = ImmutableSet.<Rule<?>>builder()
                .addAll(new SimplifyRowExpressions(getMetadata()).rules())
                .addAll(new RemoveMapCastRule(getFunctionManager()).rules())
                .build();

        tester().assertThat(rulesUnderTest)
                .setSystemProperty(REMOVE_MAP_CAST, "true")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", DOUBLE);
                    VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE));
                    VariableReferenceExpression key = p.variable("key", BIGINT);
                    return p.project(
                            assignment(a, p.rowExpression(original)),
                            p.values(feature, key));
                })
                .matches(
                        project(
                                ImmutableMap.of("a", expression(rewritten)),
                                values("feature", "key")));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import static com.facebook.presto.SystemSessionProperties.QUICK_DISTINCT_LIMIT_ENABLED;
import static com.facebook.presto.SystemSessionProperties.RANDOMIZE_OUTER_JOIN_NULL_KEY;
import static com.facebook.presto.SystemSessionProperties.RANDOMIZE_OUTER_JOIN_NULL_KEY_STRATEGY;
import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST;
import static com.facebook.presto.SystemSessionProperties.REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN;
import static com.facebook.presto.SystemSessionProperties.REWRITE_CASE_TO_MAP_ENABLED;
import static com.facebook.presto.SystemSessionProperties.REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION;
Expand Down Expand Up @@ -7485,4 +7486,22 @@ public void testRepeat()
assertQuery("select repeat(k1, k2), repeat(k1, 5), repeat(3, k2) from (values (3, 2), (5, 4), (2, 4))t(k1, k2)",
"values (array[3, 3], array[3,3,3,3,3], array[3, 3]), (array[5, 5, 5, 5], array[5, 5, 5, 5, 5], array[3, 3, 3, 3]), (array[2, 2, 2, 2], array[2, 2, 2, 2, 2], array[3, 3, 3, 3])");
}

@Test
public void testRemoveMapCast()
{
    Session removeMapCastEnabled = Session.builder(getSession())
            .setSystemProperty(REMOVE_MAP_CAST, "true")
            .build();

    // Two-row inline table t(feature, key). Placeholders: projection, then (map, key) for each row.
    String queryTemplate = "select %s from (values (%s, %s), (%s, %s)) t(feature, key)";
    String integerKeyedMap = "map(array[cast(1 as integer), 2, 3, 4], array[0.3, 0.5, 0.9, 0.1])";
    String varcharKeyedMap = "map(array[cast(1 as varchar), '2', '3', '4'], array[0.3, 0.5, 0.9, 0.1])";
    String subscript = "feature[key]";
    String elementAt = "element_at(feature, key)";
    String bigintTwo = "cast(2 as bigint)";
    String beyondIntegerRange = "400000000000";

    // Keys present in the map: subscript and element_at return the mapped values
    assertQuery(
            removeMapCastEnabled,
            String.format(queryTemplate, subscript, integerKeyedMap, bigintTwo, integerKeyedMap, "4"),
            "values 0.5, 0.1");
    assertQuery(
            removeMapCastEnabled,
            String.format(queryTemplate, elementAt, integerKeyedMap, bigintTwo, integerKeyedMap, "4"),
            "values 0.5, 0.1");

    // Key that does not fit in an integer: element_at yields null, subscript fails
    assertQuery(
            removeMapCastEnabled,
            String.format(queryTemplate, elementAt, integerKeyedMap, bigintTwo, integerKeyedMap, beyondIntegerRange),
            "values 0.5, null");
    assertQueryFails(
            removeMapCastEnabled,
            String.format(queryTemplate, subscript, integerKeyedMap, bigintTwo, integerKeyedMap, beyondIntegerRange),
            ".*Out of range for integer.*");

    // Map keyed by varchar: results are still correct with the property enabled
    assertQuery(
            removeMapCastEnabled,
            String.format(queryTemplate, subscript, varcharKeyedMap, "cast('2' as varchar)", varcharKeyedMap, "'4'"),
            "values 0.5, 0.1");
}
}

0 comments on commit 91aa1b1

Please sign in to comment.