From 84cafecb80f2300795264de603a4e97e6e3b4c84 Mon Sep 17 00:00:00 2001 From: Anant Aneja <1797669+aaneja@users.noreply.github.com> Date: Tue, 21 Jan 2025 17:57:10 +0530 Subject: [PATCH] Add Exchange before GroupId to improve Partial Aggregation Based on: https://github.com/trinodb/trino/commit/dc1d66fb co-authored-by: Piotr Findeisen Based on : https://github.com/trinodb/trino/commit/c573b34ef42adf79935df302f133070f3bf0c82a co-authored-by: Lukasz Stec Based on: https://github.com/trinodb/trino/commit/29328d376abc16ad597347f967f8ff556a14363f co-authored-by: praveenkrishna --- .../presto/SystemSessionProperties.java | 10 + .../presto/cost/TaskCountEstimator.java | 8 + .../presto/sql/analyzer/FeaturesConfig.java | 14 + .../presto/sql/planner/PlanOptimizers.java | 21 +- ...wPartialAggregationOverGroupIdRuleSet.java | 361 ++++++++++++++++++ .../presto/sql/planner/plan/GroupIdNode.java | 4 +- .../presto/testing/LocalQueryRunner.java | 20 +- .../sql/analyzer/TestFeaturesConfig.java | 7 +- ...wPartialAggregationOverGroupIdRuleSet.java | 99 +++++ .../sql/planner/assertions/BasePlanTest.java | 15 +- .../planner/assertions/ExchangeMatcher.java | 20 +- .../planner/assertions/PlanMatchPattern.java | 7 +- ...wPartialAggregationOverGroupIdRuleSet.java | 224 +++++++++++ .../iterative/rule/test/PlanBuilder.java | 29 ++ .../tests/AbstractTestQueryFramework.java | 4 +- 15 files changed, 827 insertions(+), 16 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java 
b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index a86d4ae6cbe18..ae8698136bad7 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -328,6 +328,7 @@ public final class SystemSessionProperties public static final String INCLUDE_VALUES_NODE_IN_CONNECTOR_OPTIMIZER = "include_values_node_in_connector_optimizer"; public static final String SINGLE_NODE_EXECUTION_ENABLED = "single_node_execution_enabled"; public static final String EXPRESSION_OPTIMIZER_NAME = "expression_optimizer_name"; + public static final String ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID = "add_exchange_below_partial_aggregation_over_group_id"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. public static final String NATIVE_AGGREGATION_SPILL_ALL = "native_aggregation_spill_all"; @@ -1858,6 +1859,10 @@ public SystemSessionProperties( EXPRESSION_OPTIMIZER_NAME, "Configure which expression optimizer to use", featuresConfig.getExpressionOptimizerName(), + false), + booleanProperty(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, + "Enable adding an exchange below partial aggregation over a GroupId node to improve partial aggregation performance", + featuresConfig.getAddExchangeBelowPartialAggregationOverGroupId(), false)); } @@ -3164,4 +3169,9 @@ public static String getExpressionOptimizerName(Session session) { return session.getSystemProperty(EXPRESSION_OPTIMIZER_NAME, String.class); } + + public static boolean isEnabledAddExchangeBelowGroupId(Session session) + { + return session.getSystemProperty(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/TaskCountEstimator.java b/presto-main/src/main/java/com/facebook/presto/cost/TaskCountEstimator.java index 
3f192bea26eb4..8e6228078068a 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/TaskCountEstimator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/TaskCountEstimator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.cost; +import com.facebook.presto.Session; import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.InternalNodeManager; @@ -22,6 +23,8 @@ import java.util.Set; import java.util.function.IntSupplier; +import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount; +import static java.lang.Math.min; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -54,4 +57,9 @@ public int estimateSourceDistributedTaskCount() { return numberOfNodes.getAsInt(); } + + public int estimateHashedTaskCount(Session session) + { + return min(numberOfNodes.getAsInt(), getHashPartitionCount(session)); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 033c6b26f489d..8d24e8673e35e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -297,6 +297,7 @@ public class FeaturesConfig private boolean singleNodeExecutionEnabled; private boolean nativeExecutionScaleWritersThreadsEnabled; private String expressionOptimizerName = DEFAULT_EXPRESSION_OPTIMIZER_NAME; + private boolean addExchangeBelowPartialAggregationOverGroupId; public enum PartitioningPrecisionStrategy { @@ -2945,4 +2946,17 @@ public boolean isExcludeInvalidWorkerSessionProperties() { return this.setExcludeInvalidWorkerSessionProperties; } + + @Config("optimizer.add-exchange-below-partial-aggregation-over-group-id") + @ConfigDescription("Enable adding an exchange below partial 
aggregation over a GroupId node to improve partial aggregation performance") + public FeaturesConfig setAddExchangeBelowPartialAggregationOverGroupId(boolean addExchangeBelowPartialAggregationOverGroupId) + { + this.addExchangeBelowPartialAggregationOverGroupId = addExchangeBelowPartialAggregationOverGroupId; + return this; + } + + public boolean getAddExchangeBelowPartialAggregationOverGroupId() + { + return addExchangeBelowPartialAggregationOverGroupId; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 25607c03c9c86..a7ec8cb353dfd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -18,6 +18,7 @@ import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.cost.TaskCountEstimator; +import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; @@ -27,6 +28,7 @@ import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.properties.LogicalPropertiesProviderImpl; +import com.facebook.presto.sql.planner.iterative.rule.AddExchangesBelowPartialAggregationOverGroupIdRuleSet; import com.facebook.presto.sql.planner.iterative.rule.AddIntermediateAggregations; import com.facebook.presto.sql.planner.iterative.rule.AddNotNullFiltersToJoinNode; import com.facebook.presto.sql.planner.iterative.rule.CombineApproxPercentileFunctions; @@ -222,7 +224,8 @@ public PlanOptimizers( TaskCountEstimator taskCountEstimator, PartitioningProviderManager partitioningProviderManager, FeaturesConfig featuresConfig, - 
ExpressionOptimizerManager expressionOptimizerManager) + ExpressionOptimizerManager expressionOptimizerManager, + TaskManagerConfig taskManagerConfig) { this(metadata, sqlParser, @@ -238,7 +241,8 @@ public PlanOptimizers( taskCountEstimator, partitioningProviderManager, featuresConfig, - expressionOptimizerManager); + expressionOptimizerManager, + taskManagerConfig); } @PostConstruct @@ -270,7 +274,8 @@ public PlanOptimizers( TaskCountEstimator taskCountEstimator, PartitioningProviderManager partitioningProviderManager, FeaturesConfig featuresConfig, - ExpressionOptimizerManager expressionOptimizerManager) + ExpressionOptimizerManager expressionOptimizerManager, + TaskManagerConfig taskManagerConfig) { this.exporter = exporter; ImmutableList.Builder builder = ImmutableList.builder(); @@ -820,6 +825,7 @@ public PlanOptimizers( if (!noExchange) { builder.add(new ReplicateSemiJoinInDelete()); // Must run before AddExchanges + builder.add(new IterativeOptimizer( metadata, ruleStats, @@ -830,6 +836,7 @@ public PlanOptimizers( // Must run before AddExchanges and after ReplicateSemiJoinInDelete // to avoid temporarily having an invalid plan new DetermineSemiJoinDistributionType(costComparator, taskCountEstimator)))); + builder.add(new RandomizeNullKeyInOuterJoin(metadata.getFunctionAndTypeManager(), statsCalculator), new PruneUnreferencedOutputs(), new IterativeOptimizer( @@ -841,6 +848,7 @@ public PlanOptimizers( new PruneRedundantProjectionAssignments(), new InlineProjections(metadata.getFunctionAndTypeManager()), new RemoveRedundantIdentityProjections()))); + builder.add(new ShardJoins(metadata, metadata.getFunctionAndTypeManager(), statsCalculator), new PruneUnreferencedOutputs()); builder.add( @@ -914,6 +922,13 @@ public PlanOptimizers( ImmutableSet.of( new PruneJoinColumns()))); + builder.add(new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + costCalculator, + new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(taskCountEstimator, 
taskManagerConfig, metadata).rules())); + builder.add(new IterativeOptimizer( metadata, ruleStats, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java new file mode 100644 index 0000000000000..1d45d2e715025 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -0,0 +1,361 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.TaskCountEstimator; +import com.facebook.presto.cost.VariableStatsEstimate; +import com.facebook.presto.execution.TaskManagerConfig; +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.Partitioning; +import com.facebook.presto.spi.plan.PartitioningScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties; +import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.GroupIdNode; +import com.facebook.presto.sql.relational.ProjectNodeUtils; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multiset; +import io.airlift.units.DataSize; + +import java.util.Collection; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency; +import static com.facebook.presto.SystemSessionProperties.isEnabledAddExchangeBelowGroupId; +import static com.facebook.presto.matching.Capture.newCapture; +import static com.facebook.presto.matching.Pattern.nonEmpty; +import static com.facebook.presto.matching.Pattern.typeOf; +import static 
com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.fixedParallelism; +import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.deriveProperties; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.partitionedExchange; +import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.groupingColumns; +import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.step; +import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMultiset.toImmutableMultiset; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.Double.isNaN; +import static java.lang.Math.min; +import static java.util.Comparator.comparing; +import static java.util.Objects.requireNonNull; + +/** + * Transforms + * <pre>
+ *   - Exchange
+ *     - [ Projection ]
+ *       - Partial Aggregation
+ *         - GroupId
+ * </pre>
+ * to + * <pre>
+ *   - Exchange
+ *     - [ Projection ]
+ *       - Partial Aggregation
+ *         - GroupId
+ *           - LocalExchange
+ *             - RemoteExchange
+ * </pre>
+ * <p>

+ * Rationale: GroupId increases number of rows (number of times equal to number of grouping sets) and then + * partial aggregation reduces number of rows. However, under certain conditions, exchanging the rows before + * GroupId (before multiplication) makes partial aggregation more effective, resulting in less data being + * exchanged afterwards. + */ +public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet +{ + private static final Capture PROJECTION = newCapture(); + private static final Capture AGGREGATION = newCapture(); + private static final Capture GROUP_ID = newCapture(); + private static final Capture REMOTE_EXCHANGE = newCapture(); + + private static final Pattern WITH_PROJECTION = + // If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges + typeOf(ExchangeNode.class) + .matching(e -> e.getScope().isRemote()).capturedAs(REMOTE_EXCHANGE) + .with(source().matching( + // PushPartialAggregationThroughExchange adds a projection. However, it can be removed if RemoveRedundantIdentityProjections is run in the mean-time. 
+ typeOf(ProjectNode.class).matching(ProjectNodeUtils::isIdentity).capturedAs(PROJECTION) + .with(source().matching( + typeOf(AggregationNode.class).capturedAs(AGGREGATION) + .with(step().equalTo(AggregationNode.Step.PARTIAL)) + .with(nonEmpty(groupingColumns())) + .with(source().matching( + typeOf(GroupIdNode.class).capturedAs(GROUP_ID))))))); + + private static final Pattern WITHOUT_PROJECTION = + // If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges + typeOf(ExchangeNode.class) + .matching(e -> e.getScope().isRemote()).capturedAs(REMOTE_EXCHANGE) + .with(source().matching( + typeOf(AggregationNode.class).capturedAs(AGGREGATION) + .with(step().equalTo(AggregationNode.Step.PARTIAL)) + .with(nonEmpty(groupingColumns())) + .with(source().matching( + typeOf(GroupIdNode.class).capturedAs(GROUP_ID))))); + + private static final double GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY = 0.5; + private static final double ANTI_SKEWNESS_MARGIN = 3; + private final TaskCountEstimator taskCountEstimator; + private final DataSize maxPartialAggregationMemoryUsage; + private final Metadata metadata; + + public AddExchangesBelowPartialAggregationOverGroupIdRuleSet( + TaskCountEstimator taskCountEstimator, + TaskManagerConfig taskManagerConfig, + Metadata metadata) + { + this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null"); + this.maxPartialAggregationMemoryUsage = requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxPartialAggregationMemoryUsage(); + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + public Set<Rule<?>> rules() + { + return ImmutableSet.of( + belowProjectionRule(), + belowExchangeRule()); + } + + @VisibleForTesting + AddExchangesBelowExchangePartialAggregationGroupId belowExchangeRule() + { + return new AddExchangesBelowExchangePartialAggregationGroupId(); + } + + @VisibleForTesting + AddExchangesBelowProjectionPartialAggregationGroupId belowProjectionRule() + { + return new
AddExchangesBelowProjectionPartialAggregationGroupId(); + } + + @VisibleForTesting + class AddExchangesBelowProjectionPartialAggregationGroupId + extends BaseAddExchangesBelowExchangePartialAggregationGroupId + { + @Override + public Pattern getPattern() + { + return WITH_PROJECTION; + } + + @Override + public Result apply(ExchangeNode exchange, Captures captures, Context context) + { + ProjectNode project = captures.get(PROJECTION); + AggregationNode aggregation = captures.get(AGGREGATION); + GroupIdNode groupId = captures.get(GROUP_ID); + return transform(aggregation, groupId, context) + .map(newAggregation -> Result.ofPlanNode( + exchange.replaceChildren(ImmutableList.of( + project.replaceChildren(ImmutableList.of( + newAggregation)))))) + .orElseGet(Result::empty); + } + } + + @VisibleForTesting + class AddExchangesBelowExchangePartialAggregationGroupId + extends BaseAddExchangesBelowExchangePartialAggregationGroupId + { + @Override + public Pattern getPattern() + { + return WITHOUT_PROJECTION; + } + + @Override + public Result apply(ExchangeNode exchange, Captures captures, Context context) + { + AggregationNode aggregation = captures.get(AGGREGATION); + GroupIdNode groupId = captures.get(GROUP_ID); + return transform(aggregation, groupId, context) + .map(newAggregation -> { + PlanNode newExchange = exchange.replaceChildren(ImmutableList.of(newAggregation)); + return Result.ofPlanNode(newExchange); + }) + .orElseGet(Result::empty); + } + } + + private abstract class BaseAddExchangesBelowExchangePartialAggregationGroupId + implements Rule + { + @Override + public boolean isEnabled(Session session) + { + return isEnabledAddExchangeBelowGroupId(session); + } + + protected Optional transform(AggregationNode aggregation, GroupIdNode groupId, Context context) + { + Set groupingKeys = aggregation.getGroupingKeys().stream() + .filter(symbol -> !groupId.getGroupIdVariable().equals(symbol)) + .collect(toImmutableSet()); + + Multiset groupingSetHistogram = 
groupId.getGroupingSets().stream() + .flatMap(Collection::stream) + .collect(toImmutableMultiset()); + + if (!Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)) { + // TODO handle the case when some aggregation keys are pass-through in GroupId (e.g. common in all grouping sets) + // TODO handle the case when some grouping set symbols are not used in aggregation (possible?) + return Optional.empty(); + } + + double aggregationMemoryRequirements = estimateAggregationMemoryRequirements(groupingKeys, groupId, groupingSetHistogram, context); + if (isNaN(aggregationMemoryRequirements) || aggregationMemoryRequirements < maxPartialAggregationMemoryUsage.toBytes()) { + // Aggregation will be effective even without exchanges (or we have insufficient information). + return Optional.empty(); + } + + List desiredHashVariables = groupingSetHistogram.entrySet().stream() + // Take only frequently used symbols + .filter(entry -> entry.getCount() >= groupId.getGroupingSets().size() * GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY) + .map(Multiset.Entry::getElement) + // And only the symbols used in the aggregation (these are usually all symbols) + .peek(symbol -> verify(groupingKeys.contains(symbol), "%s not found in the grouping keys [%s]", symbol, groupingKeys)) + // Transform to symbols before GroupId + .map(groupId.getGroupingColumns()::get) + .collect(toImmutableList()); + + // Use only the symbol with the highest cardinality (if we have statistics). This makes partial aggregation more efficient in case of + // low correlation between symbol that are in every grouping set vs additional symbols. 
+ PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(groupId.getSource()); + desiredHashVariables = desiredHashVariables.stream() + .filter(symbol -> !isNaN(sourceStats.getVariableStatistics(symbol).getDistinctValuesCount())) + .max(comparing(symbol -> sourceStats.getVariableStatistics(symbol).getDistinctValuesCount())) + .map(symbol -> (List) ImmutableList.of(symbol)).orElse(desiredHashVariables); + + StreamPreferredProperties requiredProperties = fixedParallelism().withPartitioning(desiredHashVariables); + StreamProperties sourceProperties = derivePropertiesRecursively(groupId.getSource(), context); + if (requiredProperties.isSatisfiedBy(sourceProperties)) { + // Stream is already (locally) partitioned just as we want. + // In fact, there might be just a LocalExchange below and no Remote. For now, we give up in this situation anyway. To properly support such a situation: + // 1. aggregation effectiveness estimation below needs to consider the (helpful) fact that stream is already partitioned, so each operator will need less memory + // 2. if the local exchange becomes unnecessary (after we add a remote exchange on top of it), it should be removed. What if the local exchange is somewhere further + // down the tree? + return Optional.empty(); + } + + double estimatedGroups = estimateGroupCount(desiredHashVariables, context.getStatsProvider().getStats(groupId.getSource())); + if (isNaN(estimatedGroups) || estimatedGroups * ANTI_SKEWNESS_MARGIN < maximalConcurrencyAfterRepartition(context)) { + // Desired hash symbols form too few groups. Hashing over them would harm concurrency. + // TODO instead of taking symbols with >GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY presence, we could take symbols from high freq to low until there are enough groups + return Optional.empty(); + } + + PlanNode source = groupId.getSource(); + + // Above we only checked the data is not yet locally partitioned and it could be already globally partitioned (but not locally).
TODO avoid remote exchange in this case + // TODO If the aggregation memory requirements are only slightly above `maxPartialAggregationMemoryUsage`, adding only LocalExchange could be enough + source = partitionedExchange( + context.getIdAllocator().getNextId(), + REMOTE_STREAMING, + source, + new PartitioningScheme( + Partitioning.create(FIXED_HASH_DISTRIBUTION, desiredHashVariables), + source.getOutputVariables())); + + source = partitionedExchange( + context.getIdAllocator().getNextId(), + LOCAL, + source, + new PartitioningScheme( + Partitioning.create(FIXED_HASH_DISTRIBUTION, desiredHashVariables), + source.getOutputVariables())); + + PlanNode newGroupId = groupId.replaceChildren(ImmutableList.of(source)); + PlanNode newAggregation = aggregation.replaceChildren(ImmutableList.of(newGroupId)); + + return Optional.of(newAggregation); + } + + private int maximalConcurrencyAfterRepartition(Context context) + { + return getTaskConcurrency(context.getSession()) * taskCountEstimator.estimateHashedTaskCount(context.getSession()); + } + + private double estimateAggregationMemoryRequirements(Set groupingKeys, + GroupIdNode groupId, + Multiset groupingSetHistogram, + Context context) + { + checkArgument(Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)); // Otherwise math below would be off-topic + + PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(groupId.getSource()); + double keysMemoryRequirements = 0; + + for (List groupingSet : groupId.getGroupingSets()) { + List sourceVariables = groupingSet.stream() + .map(groupId.getGroupingColumns()::get) + .collect(toImmutableList()); + + double keyWidth = sourceStats.getOutputSizeForVariables(sourceVariables) / sourceStats.getOutputRowCount(); + double keyNdv = min(estimateGroupCount(sourceVariables, sourceStats), sourceStats.getOutputRowCount()); + + keysMemoryRequirements += keyWidth * keyNdv; + } + + // TODO consider also memory requirements for aggregation values + return 
keysMemoryRequirements; + } + + private double estimateGroupCount(List variables, PlanNodeStatsEstimate statsEstimate) + { + return variables.stream() + .map(statsEstimate::getVariableStatistics) + .mapToDouble(this::ndvIncludingNull) + // This assumes no correlation, maximum number of aggregation keys + .reduce(1, (a, b) -> a * b); + } + + private double ndvIncludingNull(VariableStatsEstimate variableStatsEstimate) + { + if (variableStatsEstimate.getNullsFraction() == 0.) { + return variableStatsEstimate.getDistinctValuesCount(); + } + return variableStatsEstimate.getDistinctValuesCount() + 1; + } + + private StreamProperties derivePropertiesRecursively(PlanNode node, Context context) + { + PlanNode resolvedPlanNode = context.getLookup().resolve(node); + List inputProperties = resolvedPlanNode.getSources().stream() + .map(source -> derivePropertiesRecursively(source, context)) + .collect(toImmutableList()); + return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java index 1b91e75e7dea5..b77a1fd2e7194 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java @@ -81,7 +81,9 @@ public GroupIdNode( { super(sourceLocation, id, statsEquivalentPlanNode); this.source = requireNonNull(source); - this.groupingSets = listOfListsCopy(requireNonNull(groupingSets, "groupingSets is null")); + checkArgument(requireNonNull(groupingSets, "groupingSets is null").size() > 1, + "groupingSets must have more than one grouping set, passed set was [%s]", groupingSets); + this.groupingSets = listOfListsCopy(groupingSets); this.groupingColumns = ImmutableMap.copyOf(requireNonNull(groupingColumns)); this.aggregationArguments = 
ImmutableList.copyOf(aggregationArguments); this.groupIdVariable = requireNonNull(groupIdVariable); diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index 8d2a9ac22e5bb..60626a7dbdc14 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -353,6 +353,7 @@ public class LocalQueryRunner private final NodeSpillConfig nodeSpillConfig; private final NodeSchedulerConfig nodeSchedulerConfig; private final FragmentStatsProvider fragmentStatsProvider; + private final TaskManagerConfig taskManagerConfig; private boolean printPlan; private final PlanChecker distributedPlanChecker; @@ -378,19 +379,25 @@ public LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, F public LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, NodeSpillConfig nodeSpillConfig, boolean withInitialTransaction, boolean alwaysRevokeMemory) { - this(defaultSession, featuresConfig, functionsConfig, nodeSpillConfig, withInitialTransaction, alwaysRevokeMemory, 1, new ObjectMapper()); + this(defaultSession, featuresConfig, functionsConfig, nodeSpillConfig, withInitialTransaction, alwaysRevokeMemory, new ObjectMapper()); } public LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, NodeSpillConfig nodeSpillConfig, boolean withInitialTransaction, boolean alwaysRevokeMemory, ObjectMapper objectMapper) { - this(defaultSession, featuresConfig, functionsConfig, nodeSpillConfig, withInitialTransaction, alwaysRevokeMemory, 1, objectMapper); + this(defaultSession, featuresConfig, functionsConfig, nodeSpillConfig, withInitialTransaction, alwaysRevokeMemory, 1, objectMapper, new TaskManagerConfig().setTaskConcurrency(4)); } - private LocalQueryRunner(Session 
defaultSession, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, NodeSpillConfig nodeSpillConfig, boolean withInitialTransaction, boolean alwaysRevokeMemory, int nodeCountForStats, ObjectMapper objectMapper) + public LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, NodeSpillConfig nodeSpillConfig, boolean withInitialTransaction, boolean alwaysRevokeMemory, ObjectMapper objectMapper, TaskManagerConfig taskManagerConfig) + { + this(defaultSession, featuresConfig, functionsConfig, nodeSpillConfig, withInitialTransaction, alwaysRevokeMemory, 1, objectMapper, taskManagerConfig); + } + + private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, NodeSpillConfig nodeSpillConfig, boolean withInitialTransaction, boolean alwaysRevokeMemory, int nodeCountForStats, ObjectMapper objectMapper, TaskManagerConfig taskManagerConfig) { requireNonNull(defaultSession, "defaultSession is null"); checkArgument(!defaultSession.getTransactionId().isPresent() || !withInitialTransaction, "Already in transaction"); + this.taskManagerConfig = requireNonNull(taskManagerConfig, "taskManagerConfig is null"); this.nodeSpillConfig = requireNonNull(nodeSpillConfig, "nodeSpillConfig is null"); this.alwaysRevokeMemory = alwaysRevokeMemory; this.notificationExecutor = newCachedThreadPool(daemonThreadsNamed("local-query-runner-executor-%s")); @@ -624,7 +631,7 @@ public static LocalQueryRunner queryRunnerWithInitialTransaction(Session default public static LocalQueryRunner queryRunnerWithFakeNodeCountForStats(Session defaultSession, int nodeCount) { - return new LocalQueryRunner(defaultSession, new FeaturesConfig(), new FunctionsConfig(), new NodeSpillConfig(), false, false, nodeCount, new ObjectMapper()); + return new LocalQueryRunner(defaultSession, new FeaturesConfig(), new FunctionsConfig(), new NodeSpillConfig(), false, false, nodeCount, new ObjectMapper(), new TaskManagerConfig().setTaskConcurrency(4)); } @Override @@
-987,7 +994,7 @@ private List createDrivers(Session session, Plan plan, OutputFactory out pageFunctionCompiler, joinFilterFunctionCompiler, new IndexJoinLookupStats(), - new TaskManagerConfig().setTaskConcurrency(4), + taskManagerConfig, new MemoryManagerConfig(), new FunctionsConfig(), spillerFactory, @@ -1153,7 +1160,8 @@ public List getPlanOptimizers(boolean noExchange) taskCountEstimator, partitioningProviderManager, featuresConfig, - expressionOptimizerManager).getPlanningTimeOptimizers()); + expressionOptimizerManager, + taskManagerConfig).getPlanningTimeOptimizers()); return planOptimizers.build(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 4a07a25abbd04..6ebe9120e8a1a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -253,7 +253,8 @@ public void testDefaults() .setNativeExecutionScaleWritersThreadsEnabled(false) .setEnhancedCTESchedulingEnabled(true) .setExpressionOptimizerName("default") - .setExcludeInvalidWorkerSessionProperties(false)); + .setExcludeInvalidWorkerSessionProperties(false) + .setAddExchangeBelowPartialAggregationOverGroupId(false)); } @Test @@ -456,6 +457,7 @@ public void testExplicitPropertyMappings() .put("enhanced-cte-scheduling-enabled", "false") .put("expression-optimizer-name", "custom") .put("exclude-invalid-worker-session-properties", "true") + .put("optimizer.add-exchange-below-partial-aggregation-over-group-id", "true") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -655,7 +657,8 @@ public void testExplicitPropertyMappings() .setNativeExecutionScaleWritersThreadsEnabled(true) .setEnhancedCTESchedulingEnabled(false) .setExpressionOptimizerName("custom") - .setExcludeInvalidWorkerSessionProperties(true); + 
.setExcludeInvalidWorkerSessionProperties(true) + .setAddExchangeBelowPartialAggregationOverGroupId(true); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java new file mode 100644 index 0000000000000..5cffdb0dcd22d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.Session; +import com.facebook.presto.execution.TaskManagerConfig; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.assertions.PlanAssert; +import com.facebook.presto.sql.planner.plan.GroupIdNode; +import com.facebook.presto.testing.LocalQueryRunner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.units.DataSize; +import org.testng.annotations.Test; + +import static com.facebook.presto.SystemSessionProperties.ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID; +import static com.facebook.presto.SystemSessionProperties.MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static io.airlift.units.DataSize.Unit.KILOBYTE; +import static io.airlift.units.DataSize.Unit.MEGABYTE; + +public class TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet + extends BasePlanTest +{ + public TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet() + { + 
super(TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet::setup); + } + + private static LocalQueryRunner setup() + { + // We set available max-partial-aggregation-memory to a low value to allow the rule to trigger for the TPCH tiny scale factor + TaskManagerConfig taskManagerConfig = new TaskManagerConfig().setMaxPartialAggregationMemoryUsage(DataSize.succinctDataSize(1, KILOBYTE)); + return createQueryRunner(ImmutableMap.of(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, "true"), taskManagerConfig); + } + + @Test + public void testRollup() + { + assertDistributedPlan("SELECT orderkey, suppkey, partkey, sum(quantity) from lineitem GROUP BY ROLLUP(orderkey, suppkey, partkey)", + anyTree(node(GroupIdNode.class, + // Since 'orderkey' will be the variable with the highest frequency, we repartition on it + anyTree(exchange(LOCAL, REPARTITION, ImmutableList.of(), ImmutableSet.of("orderkey"), + exchange(REMOTE_STREAMING, REPARTITION, ImmutableList.of(), ImmutableSet.of("orderkey"), + anyTree(tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey"))))))))); + } + + @Test + public void testNegativeCases() + { + // MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER adds a Project for an 'expr' that passes through the GroupIdNode + // The Rule does not apply when such a variable is used in an Aggregation but not in the GroupId grouping set + Session enableMergeAggregationWithAndWithoutFilter = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER, "true") + .build(); + String sql = "select partkey, sum(quantity), sum(quantity) filter (where discount > 0.1) from lineitem group by grouping sets((), (partkey))"; + assertDistributedPlan(sql, enableMergeAggregationWithAndWithoutFilter, + anyTree(node(GroupIdNode.class, + project(ImmutableMap.of("partkey", expression("partkey"), "quantity", expression("quantity"), "expr", expression("discount > DOUBLE'0.1'")), + tableScan("lineitem", +
ImmutableMap.of("partkey", "partkey", "quantity", "quantity", "discount", "discount")))))); + + // Rule does not apply when aggregation will be effective due to a sufficiently high max-partial-aggregation-memory + TaskManagerConfig taskManagerConfig = new TaskManagerConfig().setMaxPartialAggregationMemoryUsage(DataSize.succinctDataSize(1, MEGABYTE)); + try (LocalQueryRunner queryRunner = createQueryRunner(ImmutableMap.of(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, "true"), taskManagerConfig)) { + queryRunner.inTransaction(queryRunner.getDefaultSession(), transactionSession -> { + Plan plan = queryRunner.createPlan(transactionSession, + "SELECT orderkey, suppkey, partkey, sum(quantity) from lineitem GROUP BY ROLLUP(orderkey, suppkey, partkey)", + WarningCollector.NOOP); + + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getStatsCalculator(), plan, + anyTree(node(GroupIdNode.class, + tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey"))))); + return null; + }); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java index 8194b060e0d42..45deffb292fa7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java @@ -21,6 +21,7 @@ import com.facebook.presto.common.type.TestingTypeDeserializer; import com.facebook.presto.common.type.TestingTypeManager; import com.facebook.presto.common.type.Type; +import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.WarningCollector; @@ -104,6 +105,11 @@ protected static ObjectMapper createObjectMapper() } protected static LocalQueryRunner createQueryRunner(Map sessionProperties) + { + return 
createQueryRunner(sessionProperties, new TaskManagerConfig().setTaskConcurrency(1)); + } + + protected static LocalQueryRunner createQueryRunner(Map sessionProperties, TaskManagerConfig taskManagerConfig) { Session.SessionBuilder sessionBuilder = testSessionBuilder() .setCatalog("local") @@ -112,7 +118,14 @@ protected static LocalQueryRunner createQueryRunner(Map sessionP sessionProperties.entrySet().forEach(entry -> sessionBuilder.setSystemProperty(entry.getKey(), entry.getValue())); - LocalQueryRunner queryRunner = new LocalQueryRunner(sessionBuilder.build(), new FeaturesConfig(), new FunctionsConfig(), new NodeSpillConfig(), false, false, createObjectMapper()); + LocalQueryRunner queryRunner = new LocalQueryRunner(sessionBuilder.build(), + new FeaturesConfig(), + new FunctionsConfig(), + new NodeSpillConfig(), + false, + false, + createObjectMapper(), + taskManagerConfig); queryRunner.createCatalog(queryRunner.getDefaultSession().getCatalog().get(), new TpchConnectorFactory(1), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java index cdda3f16c3070..88419b759f046 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java @@ -17,10 +17,14 @@ import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern.Ordering; import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.google.common.collect.ImmutableSet; import java.util.List; +import java.util.Set; import static 
com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; import static com.facebook.presto.sql.planner.assertions.Util.orderingSchemeMatches; @@ -34,12 +38,14 @@ final class ExchangeMatcher private final ExchangeNode.Scope scope; private final ExchangeNode.Type type; private final List orderBy; + private final Set partitionedBy; - public ExchangeMatcher(ExchangeNode.Scope scope, ExchangeNode.Type type, List orderBy) + public ExchangeMatcher(ExchangeNode.Scope scope, ExchangeNode.Type type, List orderBy, Set partitionedBy) { this.scope = scope; this.type = type; this.orderBy = requireNonNull(orderBy, "orderBy is null"); + this.partitionedBy = requireNonNull(partitionedBy, "partitionedBy is null"); } @Override @@ -69,6 +75,18 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses } } + if (!partitionedBy.isEmpty()) { + Set partitionedColumns = exchangeNode.getPartitioningScheme().getPartitioning().getArguments().stream() + .map(RowExpression.class::cast) + .filter(VariableReferenceExpression.class::isInstance) // skip non-variable arguments so they yield NO_MATCH instead of a ClassCastException + .map(argument -> ((VariableReferenceExpression) argument).getName()) + .collect(ImmutableSet.toImmutableSet()); + + if (!partitionedColumns.containsAll(partitionedBy)) { + return NO_MATCH; + } + } + return MatchResult.match(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index d06a36ae81279..8702e3c0f7a1b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -512,9 +512,14 @@ public static PlanMatchPattern exchange(ExchangeNode.Scope scope, ExchangeNode.T } public static PlanMatchPattern exchange(ExchangeNode.Scope scope, ExchangeNode.Type type, List orderBy, PlanMatchPattern...
sources) + { + return exchange(scope, type, orderBy, ImmutableSet.of(), sources); + } + + public static PlanMatchPattern exchange(ExchangeNode.Scope scope, ExchangeNode.Type type, List orderBy, Set partitionedBy, PlanMatchPattern... sources) { return node(ExchangeNode.class, sources) - .with(new ExchangeMatcher(scope, type, orderBy)); + .with(new ExchangeMatcher(scope, type, orderBy, partitionedBy)); } public static PlanMatchPattern union(PlanMatchPattern... sources) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java new file mode 100644 index 0000000000000..4239877e95318 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -0,0 +1,224 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.TaskCountEstimator; +import com.facebook.presto.cost.VariableStatsEstimate; +import com.facebook.presto.execution.TaskManagerConfig; +import com.facebook.presto.spi.plan.Partitioning; +import com.facebook.presto.spi.plan.PartitioningScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.GroupIdMatcher; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.plan.GroupIdNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; +import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; + +public class TestAddExchangesBelowPartialAggregationOverGroupIdRuleSet + extends BaseRuleTest +{ + private static AddExchangesBelowPartialAggregationOverGroupIdRuleSet.AddExchangesBelowExchangePartialAggregationGroupId belowExchangeRule(RuleTester ruleTester) + { + TaskCountEstimator taskCountEstimator = new TaskCountEstimator(() -> 4); + TaskManagerConfig taskManagerConfig = new TaskManagerConfig(); + return new AddExchangesBelowPartialAggregationOverGroupIdRuleSet( + taskCountEstimator, + taskManagerConfig, + ruleTester.getMetadata() + ).belowExchangeRule(); + } + + private static AddExchangesBelowPartialAggregationOverGroupIdRuleSet.AddExchangesBelowProjectionPartialAggregationGroupId belowProjectionRule(RuleTester ruleTester) + { + TaskCountEstimator taskCountEstimator = new TaskCountEstimator(() -> 4); + TaskManagerConfig taskManagerConfig = new TaskManagerConfig(); + return new AddExchangesBelowPartialAggregationOverGroupIdRuleSet( + taskCountEstimator, + taskManagerConfig, + ruleTester.getMetadata() + ).belowProjectionRule(); + } + + @DataProvider + public static Object[][] testDataProvider() + { + return new Object[][] { + {1000.0, 10_000.0, 1_000_000.0, "groupingKey3"}, + {1000.0, 2_000_000.0, 1_000_000.0, "groupingKey2"}, + {1000.0, 1000.0, 1000.0, "groupingKey1"} + }; + } + + @DataProvider + public static Object[][] testDataProviderMissingStats() + { + return new Object[][] { + 
{Double.NaN, 10_000.0, 1_000_000.0}, + {1000.0, Double.NaN, 1_000_000.0}, + {1000.0, 10_000.0, Double.NaN} + }; + } + + @Test(dataProvider = "testDataProvider") + public void testAddExchangesWithoutProjection(double groupingKey1NDV, double groupingKey2NDV, double groupingKey3NDV, String expectedRepartitionSymbol) + { + buildRuleAssert(groupingKey1NDV, groupingKey2NDV, groupingKey3NDV, false) + .matches(exchange( + REMOTE_STREAMING, + GATHER, + aggregation( + singleGroupingSet(ImmutableList.of("groupingKey1", "groupingKey2", "groupingKey3", "groupId")), + ImmutableMap.of(), + ImmutableList.of(), + ImmutableMap.of(), + Optional.empty(), + PARTIAL, + node(GroupIdNode.class, + exchange( + LOCAL, + REPARTITION, + ImmutableList.of(), + ImmutableSet.of(expectedRepartitionSymbol), + exchange( + REMOTE_STREAMING, + REPARTITION, + ImmutableList.of(), + ImmutableSet.of(expectedRepartitionSymbol), + values("groupingKey1", "groupingKey2", "groupingKey3")))) + .with(new GroupIdMatcher(ImmutableList.of( + ImmutableList.of("groupingKey1", "groupingKey2"), + ImmutableList.of("groupingKey1", "groupingKey3")), ImmutableMap.of(), "groupId"))))); + } + + @Test(dataProvider = "testDataProvider") + public void testAddExchangesWithProjection(double groupingKey1NDV, double groupingKey2NDV, double groupingKey3NDV, String expectedRepartitionSymbol) + { + buildRuleAssert(groupingKey1NDV, groupingKey2NDV, groupingKey3NDV, true) + .matches(exchange( + REMOTE_STREAMING, + GATHER, + project( + ImmutableMap.of( + "groupingKey1", expression("groupingKey1"), + "groupingKey2", expression("groupingKey2"), + "groupingKey3", expression("groupingKey3")), + aggregation( + singleGroupingSet(ImmutableList.of("groupingKey1", "groupingKey2", "groupingKey3", "groupId")), + ImmutableMap.of(), + ImmutableList.of(), + ImmutableMap.of(), + Optional.empty(), + PARTIAL, + node(GroupIdNode.class, + exchange( + LOCAL, + REPARTITION, + ImmutableList.of(), + ImmutableSet.of(expectedRepartitionSymbol), + exchange( + 
REMOTE_STREAMING, + REPARTITION, + ImmutableList.of(), + ImmutableSet.of(expectedRepartitionSymbol), + values("groupingKey1", "groupingKey2", "groupingKey3")))) + .with(new GroupIdMatcher(ImmutableList.of( + ImmutableList.of("groupingKey1", "groupingKey2"), + ImmutableList.of("groupingKey1", "groupingKey3")), ImmutableMap.of(), "groupId")))))); + } + + @Test(dataProvider = "testDataProviderMissingStats") + public void testDoesNotFireIfAnySourceSymbolIsMissingStats(double groupingKey1NDV, double groupingKey2NDV, double groupingKey3NDV) + { + buildRuleAssert(groupingKey1NDV, groupingKey2NDV, groupingKey3NDV, true).doesNotFire(); + buildRuleAssert(groupingKey1NDV, groupingKey2NDV, groupingKey3NDV, false).doesNotFire(); + } + + @Test + public void testDoesNotFireIfSessionPropertyIsDisabled() + { + buildRuleAssert(1000D, 1000D, 1000D, false) + .setSystemProperty(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, "false") + .doesNotFire(); + } + + private RuleAssert buildRuleAssert(double groupingKey1NDV, double groupingKey2NDV, double groupingKey3NDV, boolean withProjection) + { + RuleTester ruleTester = tester(); + String groupIdSourceId = "groupIdSourceId"; + return ruleTester.assertThat(withProjection ? 
belowProjectionRule(ruleTester) : belowExchangeRule(ruleTester)) + .setSystemProperty(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, "true") + .overrideStats(groupIdSourceId, PlanNodeStatsEstimate + .builder() + .setOutputRowCount(100_000_000) + .addVariableStatistics(ImmutableMap.of( + new VariableReferenceExpression(Optional.empty(), "groupingKey1", BIGINT), VariableStatsEstimate.builder().setDistinctValuesCount(groupingKey1NDV).build(), + new VariableReferenceExpression(Optional.empty(), "groupingKey2", BIGINT), VariableStatsEstimate.builder().setDistinctValuesCount(groupingKey2NDV).build(), + new VariableReferenceExpression(Optional.empty(), "groupingKey3", BIGINT), VariableStatsEstimate.builder().setDistinctValuesCount(groupingKey3NDV).build())) + .build()) + .on(p -> { + VariableReferenceExpression groupingKey1 = p.variable("groupingKey1", BIGINT); + VariableReferenceExpression groupingKey2 = p.variable("groupingKey2", BIGINT); + VariableReferenceExpression groupingKey3 = p.variable("groupingKey3", BIGINT); + VariableReferenceExpression groupId = p.variable("groupId", BIGINT); + + PlanNode partialAgg = p.aggregation(builder -> builder + .singleGroupingSet(groupingKey1, groupingKey2, groupingKey3, groupId) + .step(PARTIAL) + .source(p.groupId( + ImmutableList.of( + ImmutableList.of(groupingKey1, groupingKey2), + ImmutableList.of(groupingKey1, groupingKey3)), + ImmutableList.of(), + groupId, + p.values(new PlanNodeId(groupIdSourceId), groupingKey1, groupingKey2, groupingKey3)))); + + return p.exchange( + exchangeBuilder -> exchangeBuilder + .scope(REMOTE_STREAMING) + .partitioningScheme(new PartitioningScheme(Partitioning.create( + FIXED_ARBITRARY_DISTRIBUTION, + ImmutableList.of()), + ImmutableList.of(groupingKey1, groupingKey2, groupingKey3, groupId))) + .addInputsSet(groupingKey1, groupingKey2, groupingKey3, groupId) + .addSource(withProjection ?
p.project(identityAssignments(partialAgg.getOutputVariables()), partialAgg) : partialAgg)); + }); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index aa633c33a8624..6e541d36606aa 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -76,6 +76,7 @@ import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; @@ -97,6 +98,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; @@ -126,8 +128,10 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; import static java.util.Collections.emptyList; +import static java.util.function.Function.identity; public class PlanBuilder { @@ -1047,4 +1051,29 @@ public PlanNodeIdAllocator getIdAllocator() { return idAllocator; } + + public GroupIdNode groupId(List> groupingSets, List aggregationArguments, VariableReferenceExpression groupIdSymbol, PlanNode source) + { + Map groupingColumns = groupingSets.stream() + .flatMap(Collection::stream) + .distinct() + .collect(toImmutableMap(identity(), 
identity())); + return groupId(groupingSets, groupingColumns, aggregationArguments, groupIdSymbol, source); + } + + public GroupIdNode groupId(List> groupingSets, + Map groupingColumns, + List aggregationArguments, + VariableReferenceExpression groupIdSymbol, + PlanNode source) + { + return new GroupIdNode( + Optional.empty(), + idAllocator.getNextId(), + source, + groupingSets, + groupingColumns, + aggregationArguments, + groupIdSymbol); + } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index bedd15a5d9e10..cdc974260c6f3 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -22,6 +22,7 @@ import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.TaskCountEstimator; import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.nodeManager.PluginNodeManager; @@ -576,7 +577,8 @@ private QueryExplainer getQueryExplainer() featuresConfig, new ExpressionOptimizerManager( new PluginNodeManager(new InMemoryNodeManager()), - queryRunner.getMetadata().getFunctionAndTypeManager())) + queryRunner.getMetadata().getFunctionAndTypeManager()), + new TaskManagerConfig()) .getPlanningTimeOptimizers(); return new QueryExplainer( optimizers,