Skip to content

Commit

Permalink
Optimize plan hashing in case of union/join nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
pranjalssh committed Dec 10, 2022
1 parent 5a31a49 commit 40d84a4
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.JoinNode;
Expand Down Expand Up @@ -63,6 +64,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;

import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.plan.PlanCanonicalizationStrategy.DEFAULT;
Expand Down Expand Up @@ -247,7 +249,7 @@ public Optional<PlanNode> visitJoin(JoinNode node, Context context)
allFilters.add(filter);
});
}
top.getSources().forEach(source -> {
for (PlanNode source : top.getSources()) {
if (source instanceof JoinNode
&& ((JoinNode) source).getType().equals(node.getType())
&& shouldMergeJoinNodes(node.getType())) {
Expand All @@ -256,20 +258,16 @@ && shouldMergeJoinNodes(node.getType())) {
else {
sources.add(source);
}
});
}
}

// Sort sources if all are INNER, or full outer join of 2 nodes
if (shouldMergeJoinNodes(node.getType()) || (node.getType().equals(JoinNode.Type.FULL) && sources.size() == 2)) {
Map<PlanNode, String> sourceKeys = new IdentityHashMap<>();
for (PlanNode source : sources) {
Optional<CanonicalPlan> canonicalSource = generateCanonicalPlan(source, strategy, objectMapper);
if (!canonicalSource.isPresent()) {
return Optional.empty();
}
sourceKeys.put(source, canonicalSource.get().toString(objectMapper));
Optional<List<Integer>> sourceIndexes = orderSources(sources);
if (!sourceIndexes.isPresent()) {
return Optional.empty();
}
sources.sort(comparing(sourceKeys::get));
sources = sourceIndexes.get().stream().map(sources::get).collect(toImmutableList());
}

ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
Expand Down Expand Up @@ -311,34 +309,26 @@ public Optional<PlanNode> visitUnion(UnionNode node, Context context)
return Optional.empty();
}

// We want to order sources in a consistent manner. Variable names and plan node ids can mess with that because of
// our stateful canonicalization using `variableAllocator` and `planNodeIdAllocator`.
// So, we first try to canonicalize each source independently. They may have conflicting variable names, but we only use it
// to decide order, and then canonicalize them properly again.
// This can lead to O(n * h) time complexity, where n is number of plan nodes, and h is height of plan tree. This
// is at par with other pieces like hashing each plan node.
Multimap<String, Integer> sourceToPosition = TreeMultimap.create();
for (int i = 0; i < node.getSources().size(); ++i) {
PlanNode source = node.getSources().get(i);
Optional<CanonicalPlan> canonicalSource = generateCanonicalPlan(source, strategy, objectMapper);
if (!canonicalSource.isPresent()) {
return Optional.empty();
}
sourceToPosition.put(canonicalSource.get().toString(objectMapper), i);
Optional<List<Integer>> sourceIndexes = orderSources(node.getSources());
if (!sourceIndexes.isPresent()) {
return Optional.empty();
}

ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();
ImmutableList.Builder<PlanNode> canonicalSources = ImmutableList.builder();
ImmutableList.Builder<VariableReferenceExpression> outputVariables = ImmutableList.builder();
ImmutableMap.Builder<VariableReferenceExpression, List<VariableReferenceExpression>> outputsToInputs = ImmutableMap.builder();

sourceToPosition.forEach((ignored, index) -> {
PlanNode canonicalSource = node.getSources().get(index).accept(this, context).get();
sources.add(canonicalSource);
});
for (Integer sourceIndex : sourceIndexes.get()) {
Optional<PlanNode> canonicalSource = node.getSources().get(sourceIndex).accept(this, context);
if (!canonicalSource.isPresent()) {
return Optional.empty();
}
canonicalSources.add(canonicalSource.get());
}

node.getVariableMapping().forEach((outputVariable, sourceVariables) -> {
ImmutableList.Builder<VariableReferenceExpression> newSourceVariablesBuilder = ImmutableList.builder();
sourceToPosition.forEach((ignored, index) -> {
sourceIndexes.get().forEach(index -> {
newSourceVariablesBuilder.add(inlineAndCanonicalize(context.getExpressions(), sourceVariables.get(index)));
});
ImmutableList<VariableReferenceExpression> newSourceVariables = newSourceVariablesBuilder.build();
Expand All @@ -351,7 +341,7 @@ public Optional<PlanNode> visitUnion(UnionNode node, Context context)
PlanNode result = new UnionNode(
Optional.empty(),
planNodeidAllocator.getNextId(),
sources.build(),
canonicalSources.build(),
outputVariables.build().stream().sorted().collect(toImmutableList()),
ImmutableSortedMap.copyOf(outputsToInputs.build()));

Expand Down Expand Up @@ -599,6 +589,56 @@ public Optional<PlanNode> visitProject(ProjectNode node, Context context)
return Optional.of(canonicalPlan);
}

// Variable names and plan node ids can change with what order we process nodes because of our
// stateful canonicalization using `variableAllocator` and `planNodeIdAllocator`.
// We want to order sources in a consistent manner, because the order matters when hashing plan.
// Returns a list of indices in input sources array, with a canonical order.
private Optional<List<Integer>> orderSources(List<PlanNode> sources)
{
// Try heuristic where we sort sources by the tables they scan.
Optional<List<Integer>> sourcesByTables = orderSourcesByTables(sources);
if (sourcesByTables.isPresent()) {
return sourcesByTables;
}

// We canonicalize each source independently, and use its representation to order sources.
Multimap<String, Integer> sourceToPosition = TreeMultimap.create();
for (int i = 0; i < sources.size(); ++i) {
Optional<CanonicalPlan> canonicalSource = generateCanonicalPlan(sources.get(i), strategy, objectMapper);
if (!canonicalSource.isPresent()) {
return Optional.empty();
}
sourceToPosition.put(canonicalSource.get().toString(objectMapper), i);
}
return Optional.of(sourceToPosition.values().stream().collect(toImmutableList()));
}

// Order sources by list of tables they use. If any 2 sources are using the same set of tables, we give up
// and return Optional.empty().
// Returns a list of indices in input sources array, with a canonical order
private Optional<List<Integer>> orderSourcesByTables(List<PlanNode> sources)
{
Multimap<String, Integer> sourceToPosition = TreeMultimap.create();
for (int i = 0; i < sources.size(); ++i) {
List<String> tables = new ArrayList<>();

PlanNodeSearcher.searchFrom(sources.get(i))
.where(node -> node instanceof TableScanNode)
.findAll()
.forEach(node -> tables.add(((TableScanNode) node).getTable().getConnectorHandle().toString()));
sourceToPosition.put(tables.stream().sorted().collect(Collectors.joining(",")), i);
}
String lastIdentifier = ",";
for (Map.Entry<String, Integer> entry : sourceToPosition.entries()) {
String identifier = entry.getKey();
if (lastIdentifier.equals(identifier)) {
return Optional.empty();
}
lastIdentifier = identifier;
}
return Optional.of(sourceToPosition.values().stream().collect(toImmutableList()));
}

private static class CanonicalWriterTarget
extends TableWriterNode.WriterTarget
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,34 @@ public void testUnion()
anyTree(node(ProjectNode.class, node(FilterNode.class, anyTree(anyTree(any()), anyTree(any())))).withOutputRowCount(4)));
}

@Test
public void testUnionMultiple()
{
assertPlan(
"SELECT * FROM nation where substr(name, 1, 1) = 'A' UNION ALL " +
"SELECT * FROM nation where substr(name, 1, 1) = 'B' UNION ALL " +
"SELECT * FROM nation where substr(name, 1, 1) = 'C'",
anyTree(node(ExchangeNode.class, anyTree(any()), anyTree(any()), anyTree(any())).withOutputRowCount(Double.NaN)));

executeAndTrackHistory("SELECT * FROM nation where substr(name, 1, 1) = 'A' UNION ALL " +
"SELECT * FROM nation where substr(name, 1, 1) = 'B' UNION ALL " +
"SELECT * FROM nation where substr(name, 1, 1) = 'C'");
assertPlan(
"SELECT * FROM nation where substr(name, 1, 1) = 'B' UNION ALL " +
"SELECT * FROM nation where substr(name, 1, 1) = 'C' UNION ALL " +
"SELECT * FROM nation where substr(name, 1, 1) = 'A'",
anyTree(node(ExchangeNode.class, anyTree(any()), anyTree(any()), anyTree(any())).withOutputRowCount(5)));

assertPlan(
"SELECT nationkey FROM nation where substr(name, 1, 1) = 'A' UNION ALL SELECT nationkey FROM customer where nationkey < 10",
anyTree(node(ExchangeNode.class, anyTree(any()), anyTree(any())).withOutputRowCount(Double.NaN)));

executeAndTrackHistory("SELECT nationkey FROM nation where substr(name, 1, 1) = 'A' UNION ALL SELECT nationkey FROM customer where nationkey < 10");
assertPlan(
" SELECT nationkey FROM customer where nationkey < 10 UNION ALL SELECT nationkey FROM nation where substr(name, 1, 1) = 'A'",
anyTree(node(ExchangeNode.class, anyTree(any()), anyTree(any())).withOutputRowCount(601)));
}

@Test
public void testJoin()
{
Expand Down

0 comments on commit 40d84a4

Please sign in to comment.