From 40d84a406553fae46e1fed53f89ff0956855e678 Mon Sep 17 00:00:00 2001
From: Pranjal Shankhdhar
Date: Tue, 6 Dec 2022 13:38:47 -0800
Subject: [PATCH] Optimize plan hashing in case of union/join nodes

visitJoin and visitUnion used to canonicalize every source independently
just to obtain a string key for ordering the sources, and then
canonicalized the sources again for the real plan.

Both visitors now call a shared orderSources() helper that returns the
source indices in canonical order:

  * It first tries orderSourcesByTables(), which keys each source by the
    sorted, comma-joined toString() of the connector table handles of
    all TableScanNodes below it. If two sources produce the same key the
    heuristic gives up.
  * Otherwise it falls back to the previous approach of canonicalizing
    each source independently and ordering by the canonical plan string.

visitUnion additionally returns Optional.empty() when a source cannot be
canonicalized instead of calling get() on the empty Optional.

testUnionMultiple covers both paths: a three-way union over the same
table (fallback) and a two-way union over different tables (heuristic),
each re-planned with the sources permuted.
---
 .../sql/planner/CanonicalPlanGenerator.java | 102 ++++++++++++------
 .../TestHistoryBasedStatsTracking.java      |  28 +++++
 2 files changed, 99 insertions(+), 31 deletions(-)

diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java
index 8407cf5396981..a06fb3a04287d 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java
@@ -34,6 +34,7 @@
 import com.facebook.presto.spi.relation.CallExpression;
 import com.facebook.presto.spi.relation.RowExpression;
 import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
 import com.facebook.presto.sql.planner.plan.GroupIdNode;
 import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
 import com.facebook.presto.sql.planner.plan.JoinNode;
@@ -63,6 +64,7 @@
 import java.util.Optional;
 import java.util.Set;
 import java.util.Stack;
+import java.util.stream.Collectors;
 
 import static com.facebook.presto.common.function.OperatorType.EQUAL;
 import static com.facebook.presto.common.plan.PlanCanonicalizationStrategy.DEFAULT;
@@ -247,7 +249,7 @@ public Optional<PlanNode> visitJoin(JoinNode node, Context context)
                     allFilters.add(filter);
                 });
             }
-            top.getSources().forEach(source -> {
+            for (PlanNode source : top.getSources()) {
                 if (source instanceof JoinNode
                         && ((JoinNode) source).getType().equals(node.getType())
                         && shouldMergeJoinNodes(node.getType())) {
@@ -256,20 +258,16 @@ && shouldMergeJoinNodes(node.getType())) {
                 else {
                     sources.add(source);
                 }
-            });
+            }
         }
 
         // Sort sources if all are INNER, or full outer join of 2 nodes
         if (shouldMergeJoinNodes(node.getType()) || (node.getType().equals(JoinNode.Type.FULL) && sources.size() == 2)) {
-            Map<PlanNode, String> sourceKeys = new IdentityHashMap<>();
-            for (PlanNode source : sources) {
-                Optional<CanonicalPlan> canonicalSource = generateCanonicalPlan(source, strategy, objectMapper);
-                if (!canonicalSource.isPresent()) {
-                    return Optional.empty();
-                }
-                sourceKeys.put(source, canonicalSource.get().toString(objectMapper));
+            Optional<List<Integer>> sourceIndexes = orderSources(sources);
+            if (!sourceIndexes.isPresent()) {
+                return Optional.empty();
             }
-            sources.sort(comparing(sourceKeys::get));
+            sources = sourceIndexes.get().stream().map(sources::get).collect(toImmutableList());
         }
 
         ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
@@ -311,34 +309,26 @@ public Optional<PlanNode> visitUnion(UnionNode node, Context context)
             return Optional.empty();
         }
 
-        // We want to order sources in a consistent manner. Variable names and plan node ids can mess with that because of
-        // our stateful canonicalization using `variableAllocator` and `planNodeIdAllocator`.
-        // So, we first try to canonicalize each source independently. They may have conflicting variable names, but we only use it
-        // to decide order, and then canonicalize them properly again.
-        // This can lead to O(n * h) time complexity, where n is number of plan nodes, and h is height of plan tree. This
-        // is at par with other pieces like hashing each plan node.
-        Multimap<String, Integer> sourceToPosition = TreeMultimap.create();
-        for (int i = 0; i < node.getSources().size(); ++i) {
-            PlanNode source = node.getSources().get(i);
-            Optional<CanonicalPlan> canonicalSource = generateCanonicalPlan(source, strategy, objectMapper);
-            if (!canonicalSource.isPresent()) {
-                return Optional.empty();
-            }
-            sourceToPosition.put(canonicalSource.get().toString(objectMapper), i);
+        Optional<List<Integer>> sourceIndexes = orderSources(node.getSources());
+        if (!sourceIndexes.isPresent()) {
+            return Optional.empty();
         }
 
-        ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();
+        ImmutableList.Builder<PlanNode> canonicalSources = ImmutableList.builder();
         ImmutableList.Builder<VariableReferenceExpression> outputVariables = ImmutableList.builder();
         ImmutableMap.Builder<VariableReferenceExpression, List<VariableReferenceExpression>> outputsToInputs = ImmutableMap.builder();
 
-        sourceToPosition.forEach((ignored, index) -> {
-            PlanNode canonicalSource = node.getSources().get(index).accept(this, context).get();
-            sources.add(canonicalSource);
-        });
+        for (Integer sourceIndex : sourceIndexes.get()) {
+            Optional<PlanNode> canonicalSource = node.getSources().get(sourceIndex).accept(this, context);
+            if (!canonicalSource.isPresent()) {
+                return Optional.empty();
+            }
+            canonicalSources.add(canonicalSource.get());
+        }
 
         node.getVariableMapping().forEach((outputVariable, sourceVariables) -> {
             ImmutableList.Builder<VariableReferenceExpression> newSourceVariablesBuilder = ImmutableList.builder();
-            sourceToPosition.forEach((ignored, index) -> {
+            sourceIndexes.get().forEach(index -> {
                 newSourceVariablesBuilder.add(inlineAndCanonicalize(context.getExpressions(), sourceVariables.get(index)));
             });
             ImmutableList<VariableReferenceExpression> newSourceVariables = newSourceVariablesBuilder.build();
@@ -351,7 +341,7 @@ public Optional<PlanNode> visitUnion(UnionNode node, Context context)
         PlanNode result = new UnionNode(
                 Optional.empty(),
                 planNodeidAllocator.getNextId(),
-                sources.build(),
+                canonicalSources.build(),
                 outputVariables.build().stream().sorted().collect(toImmutableList()),
                 ImmutableSortedMap.copyOf(outputsToInputs.build()));
 
@@ -599,6 +589,56 @@ public Optional<PlanNode> visitProject(ProjectNode node, Context context)
         return Optional.of(canonicalPlan);
     }
 
+    // Variable names and plan node ids can change with what order we process nodes because of our
+    // stateful canonicalization using `variableAllocator` and `planNodeIdAllocator`.
+    // We want to order sources in a consistent manner, because the order matters when hashing plan.
+    // Returns a list of indices in input sources array, with a canonical order.
+    private Optional<List<Integer>> orderSources(List<PlanNode> sources)
+    {
+        // Try heuristic where we sort sources by the tables they scan.
+        Optional<List<Integer>> sourcesByTables = orderSourcesByTables(sources);
+        if (sourcesByTables.isPresent()) {
+            return sourcesByTables;
+        }
+
+        // We canonicalize each source independently, and use its representation to order sources.
+        Multimap<String, Integer> sourceToPosition = TreeMultimap.create();
+        for (int i = 0; i < sources.size(); ++i) {
+            Optional<CanonicalPlan> canonicalSource = generateCanonicalPlan(sources.get(i), strategy, objectMapper);
+            if (!canonicalSource.isPresent()) {
+                return Optional.empty();
+            }
+            sourceToPosition.put(canonicalSource.get().toString(objectMapper), i);
+        }
+        return Optional.of(sourceToPosition.values().stream().collect(toImmutableList()));
+    }
+
+    // Order sources by list of tables they use. If any 2 sources are using the same set of tables, we give up
+    // and return Optional.empty().
+    // Returns a list of indices in input sources array, with a canonical order
+    private Optional<List<Integer>> orderSourcesByTables(List<PlanNode> sources)
+    {
+        Multimap<String, Integer> sourceToPosition = TreeMultimap.create();
+        for (int i = 0; i < sources.size(); ++i) {
+            List<String> tables = new ArrayList<>();
+
+            PlanNodeSearcher.searchFrom(sources.get(i))
+                    .where(node -> node instanceof TableScanNode)
+                    .findAll()
+                    .forEach(node -> tables.add(((TableScanNode) node).getTable().getConnectorHandle().toString()));
+            sourceToPosition.put(tables.stream().sorted().collect(Collectors.joining(",")), i);
+        }
+        String lastIdentifier = ",";
+        for (Map.Entry<String, Integer> entry : sourceToPosition.entries()) {
+            String identifier = entry.getKey();
+            if (lastIdentifier.equals(identifier)) {
+                return Optional.empty();
+            }
+            lastIdentifier = identifier;
+        }
+        return Optional.of(sourceToPosition.values().stream().collect(toImmutableList()));
+    }
+
     private static class CanonicalWriterTarget
             extends TableWriterNode.WriterTarget
     {
diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java b/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java
index ee505c1f5958d..dbff7df7abf16 100644
--- a/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java
+++ b/presto-tests/src/test/java/com/facebook/presto/execution/TestHistoryBasedStatsTracking.java
@@ -135,6 +135,34 @@ public void testUnion()
                 anyTree(node(ProjectNode.class, node(FilterNode.class, anyTree(anyTree(any()), anyTree(any())))).withOutputRowCount(4)));
     }
 
+    @Test
+    public void testUnionMultiple()
+    {
+        assertPlan(
+                "SELECT * FROM nation where substr(name, 1, 1) = 'A' UNION ALL " +
+                        "SELECT * FROM nation where substr(name, 1, 1) = 'B' UNION ALL " +
+                        "SELECT * FROM nation where substr(name, 1, 1) = 'C'",
+                anyTree(node(ExchangeNode.class, anyTree(any()), anyTree(any()), anyTree(any())).withOutputRowCount(Double.NaN)));
+
+        executeAndTrackHistory("SELECT * FROM nation where substr(name, 1, 1) = 'A' UNION ALL " +
+                "SELECT * FROM nation where substr(name, 1, 1) = 'B' UNION ALL " +
+                "SELECT * FROM nation where substr(name, 1, 1) = 'C'");
+        assertPlan(
+                "SELECT * FROM nation where substr(name, 1, 1) = 'B' UNION ALL " +
+                        "SELECT * FROM nation where substr(name, 1, 1) = 'C' UNION ALL " +
+                        "SELECT * FROM nation where substr(name, 1, 1) = 'A'",
+                anyTree(node(ExchangeNode.class, anyTree(any()), anyTree(any()), anyTree(any())).withOutputRowCount(5)));
+
+        assertPlan(
+                "SELECT nationkey FROM nation where substr(name, 1, 1) = 'A' UNION ALL SELECT nationkey FROM customer where nationkey < 10",
+                anyTree(node(ExchangeNode.class, anyTree(any()), anyTree(any())).withOutputRowCount(Double.NaN)));
+
+        executeAndTrackHistory("SELECT nationkey FROM nation where substr(name, 1, 1) = 'A' UNION ALL SELECT nationkey FROM customer where nationkey < 10");
+        assertPlan(
+                " SELECT nationkey FROM customer where nationkey < 10 UNION ALL SELECT nationkey FROM nation where substr(name, 1, 1) = 'A'",
+                anyTree(node(ExchangeNode.class, anyTree(any()), anyTree(any())).withOutputRowCount(601)));
+    }
+
     @Test
     public void testJoin()
     {