Skip to content

Commit

Permalink
Improve CTE scheduling by scheduling subgraphs independently
Browse files Browse the repository at this point in the history
Subscribers: O4263 subscribe to presto changes

JIRA Issues: PRESTO-6697

Differential Revision: https://code.uberinternal.com/D13417433
  • Loading branch information
jaystarshot committed Apr 3, 2024
1 parent 3d071e0 commit 97fb98e
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ protected QueryRunner createQueryRunner()
Optional.empty());
}

@Test
public void testCteExecutionWhereOneCteRemovedBySimplifyEmptyInputRule()
{
    // CTE "t" aggregates over a subquery filtered by "where false", while "t1" reads all of orders;
    // "b" unions the orderkeys of both. Per the test name, "t" is the CTE expected to be removed
    // by the simplify-empty-input rule.
    String query = "WITH t as(select orderkey, count(*) as count from (select orderkey from orders where false) group by orderkey)," +
            "t1 as (SELECT * FROM orders)," +
            " b AS ((SELECT orderkey FROM t) UNION (SELECT orderkey FROM t1)) " +
            "SELECT * FROM b";
    QueryRunner runner = getQueryRunner();
    // Run the same query under the materialized session and under the default session,
    // then check that both executions return matching results.
    compareResults(
            runner.execute(getMaterializedSession(), query),
            runner.execute(getSession(), query));
}

@Test
public void testSimplePersistentCte()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,25 +204,33 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext<FragmentPr
// Builds SubPlans for the CTE producers of a SequenceNode and registers them as children of the
// current fragment (context.get()), then continues the rewrite with the node's primary source.
// Within a group of producers, the SubPlan built for the last producer becomes the child of the
// SubPlan built for the producer before it, and so on down to index 0.
@Override
public PlanNode visitSequence(SequenceNode node, RewriteContext<FragmentProperties> context)
{
// Since the producers are topologically sorted by the LogicalCtePlanner, we need to make sure that the execution order follows that sorting.
// This can be optimized further to avoid blocking non-dependent producers.
int cteProducerCount = node.getCteProducers().size();
checkArgument(cteProducerCount >= 1, "Sequence Node has 0 CTE producers");
PlanNode source = node.getCteProducers().get(cteProducerCount - 1);
FragmentProperties childProperties = new FragmentProperties(new PartitioningScheme(
Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
source.getOutputVariables()));
SubPlan lastSubPlan = buildSubPlan(source, childProperties, context);

for (int sourceIndex = cteProducerCount - 2; sourceIndex >= 0; sourceIndex--) {
source = node.getCteProducers().get(sourceIndex);
childProperties = new FragmentProperties(new PartitioningScheme(
// To ensure that the execution order is maintained, we use an independent dependency graph.
// This approach creates subgraphs sequentially, enhancing control over the execution flow. However, there are optimization opportunities:
// 1. We could consider blocking only the CTE (Common Table Expression) consumer stages that are in a reading state.
// This approach sounds good on paper but may not be ideal, as it can block the entire query, leading to resource wastage since no progress can be made until the writing operations are complete.
// 2. TODO: Another improvement would be to schedule the execution of subgraphs based on their order in the overall execution plan instead of the topological sorting done here,
// but that needs a change to the plan section framework so that it can handle the same child planSection.
List<List<PlanNode>> independentCteProducerSubgraphs = node.getIndependentCteProducers();
for (List<PlanNode> cteProducerSubgraph : independentCteProducerSubgraphs) {
int cteProducerCount = cteProducerSubgraph.size();
checkArgument(cteProducerCount >= 1, "CteProducer subgraph has 0 CTE producers");
// Start from the last producer of the subgraph; its SubPlan has no CTE-producer child.
PlanNode source = cteProducerSubgraph.get(cteProducerCount - 1);
FragmentProperties childProperties = new FragmentProperties(new PartitioningScheme(
Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
source.getOutputVariables()));
childProperties.addChildren(ImmutableList.of(lastSubPlan));
lastSubPlan = buildSubPlan(source, childProperties, context);
SubPlan lastSubPlan = buildSubPlan(source, childProperties, context);
// Walk the remaining producers from back to front, chaining each new SubPlan on top of the previous one.
for (int sourceIndex = cteProducerCount - 2; sourceIndex >= 0; sourceIndex--) {
source = cteProducerSubgraph.get(sourceIndex);
childProperties = new FragmentProperties(new PartitioningScheme(
Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
source.getOutputVariables()));
childProperties.addChildren(ImmutableList.of(lastSubPlan));
lastSubPlan = buildSubPlan(source, childProperties, context);
}
// This makes sure that the sectionedPlans generated in com.facebook.presto.execution.scheduler.StreamingPlanSection
// are independent and thus can be scheduled concurrently.
context.get().addChildren(ImmutableList.of(lastSubPlan));
}
context.get().addChildren(ImmutableList.of(lastSubPlan));
return node.getPrimarySource().accept(this, context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.graph.Graph;
import com.google.common.graph.GraphBuilder;
import com.google.common.graph.MutableGraph;
import com.google.common.graph.Traverser;
Expand Down Expand Up @@ -135,8 +136,11 @@ public PlanNode transformPersistentCtes(Session session, PlanNode root)
return transformedCte;
}
isPlanRewritten = true;
SequenceNode sequenceNode = new SequenceNode(root.getSourceLocation(), planNodeIdAllocator.getNextId(), topologicalOrderedList,
transformedCte.getSources().get(0));
SequenceNode sequenceNode = new SequenceNode(root.getSourceLocation(),
planNodeIdAllocator.getNextId(),
topologicalOrderedList,
transformedCte.getSources().get(0),
context.createIndexedGraphFromTopologicallySortedCteProducers(topologicalOrderedList));
return root.replaceChildren(Arrays.asList(sequenceNode));
}

Expand Down Expand Up @@ -278,16 +282,17 @@ public static class LogicalCteOptimizerContext
public Map<String, CteProducerNode> cteProducerMap;

// a -> b indicates that b needs to be processed before a
MutableGraph<String> graph;
public Stack<String> activeCteStack;
private MutableGraph<String> cteDependencyGraph;

public Set<String> complexCtes;
private Stack<String> activeCteStack;

private Set<String> complexCtes;

// Creates an empty context: no registered CTE producers, an empty directed dependency graph,
// an empty active-CTE stack and an empty set of complex CTEs.
public LogicalCteOptimizerContext()
{
cteProducerMap = new HashMap<>();
// The CTE graph will never have cycles because SQL does not allow them
graph = GraphBuilder.directed().build();
cteDependencyGraph = GraphBuilder.directed().build();
activeCteStack = new Stack<>();
complexCtes = new HashSet<>();
}
Expand Down Expand Up @@ -319,9 +324,9 @@ public Optional<String> peekActiveCte()

// Registers currentCte as a node of the dependency graph and, when peekActiveCte() returns a
// value, adds a directed edge from currentCte to that CTE. When no CTE is active, only the node
// is added.
public void addDependency(String currentCte)
{
graph.addNode(currentCte);
cteDependencyGraph.addNode(currentCte);
Optional<String> parentCte = peekActiveCte();
parentCte.ifPresent(s -> graph.putEdge(currentCte, s));
parentCte.ifPresent(s -> cteDependencyGraph.putEdge(currentCte, s));
}

public void addComplexCte(String cteId)
Expand All @@ -342,9 +347,30 @@ public boolean isComplexCte(String cteId)
// Returns the CTE producers looked up from cteProducerMap, ordered by a depth-first post-order
// traversal of the dependency graph started from all of its nodes. The result is an immutable list.
public List<PlanNode> getTopologicalOrdering()
{
ImmutableList.Builder<PlanNode> topSortedCteProducerListBuilder = ImmutableList.builder();
Traverser.forGraph(graph).depthFirstPostOrder(graph.nodes())
Traverser.forGraph(cteDependencyGraph).depthFirstPostOrder(cteDependencyGraph.nodes())
// Each visited CTE name is mapped to its producer node.
.forEach(cteName -> topSortedCteProducerListBuilder.add(cteProducerMap.get(cteName)));
return topSortedCteProducerListBuilder.build();
}

/**
 * Translates the CTE-name dependency graph into a graph over list positions.
 *
 * Every element of {@code topologicalSortedCteProducerList} is cast to {@code CteProducerNode} and
 * becomes a node identified by its index in the list. Every edge of the CTE dependency graph is
 * then copied as an edge between the indexes of the corresponding CTE ids.
 *
 * @param topologicalSortedCteProducerList CTE producer nodes; positions in this list are the node ids of the result
 * @return a directed graph whose nodes are indexes into {@code topologicalSortedCteProducerList}
 */
public Graph<Integer> createIndexedGraphFromTopologicallySortedCteProducers(List<PlanNode> topologicalSortedCteProducerList)
{
    int producerCount = topologicalSortedCteProducerList.size();
    MutableGraph<Integer> indexedGraph = GraphBuilder.directed().expectedNodeCount(producerCount).build();

    // Remember the list position of every CTE id and register that position as a graph node.
    Map<String, Integer> positionByCteId = new HashMap<>();
    int position = 0;
    for (PlanNode producer : topologicalSortedCteProducerList) {
        positionByCteId.put(((CteProducerNode) producer).getCteId(), position);
        indexedGraph.addNode(position);
        position++;
    }

    // Mirror each name-based edge as an edge between the matching positions.
    for (String cteId : cteDependencyGraph.nodes()) {
        for (String successor : cteDependencyGraph.successors(cteId)) {
            indexedGraph.putEdge(positionByCteId.get(cteId), positionByCteId.get(successor));
        }
    }
    return indexedGraph;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ public PlanNode visitSequence(SequenceNode node, RewriteContext<Set<VariableRefe
.map(leftSource -> context.rewrite(leftSource, leftInputs)).collect(toImmutableList());
Set<VariableReferenceExpression> rightInputs = ImmutableSet.copyOf(node.getPrimarySource().getOutputVariables());
PlanNode primarySource = context.rewrite(node.getPrimarySource(), rightInputs);
return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource);
return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource, node.getCteDependencyGraph());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isSimplifyPlanWithEmptyInputEnabled;
Expand Down Expand Up @@ -209,23 +211,33 @@ else if (isEmptyNode(rewrittenRight)) {
// Rewrites every CTE producer and the primary source of a SequenceNode, dropping producers whose
// rewritten form is an empty node. If the primary source is empty or no producer survives, the
// rewritten primary source is returned on its own.
public PlanNode visitSequence(SequenceNode node, RewriteContext<Void> context)
{
List<PlanNode> cteProducers = node.getCteProducers();
List<PlanNode> newSequenceChildrenList = new ArrayList<>();
List<PlanNode> newCteProducerList = new ArrayList<>();
// Visit in the order of execution (from the last producer in the list to the first)
// removedIndexes holds positions in the original cteProducers list whose rewrite produced an empty node.
Set<Integer> removedIndexes = new HashSet<>();
for (int i = cteProducers.size() - 1; i >= 0; i--) {
PlanNode rewrittenProducer = context.rewrite(cteProducers.get(i));
if (!isEmptyNode(rewrittenProducer)) {
newSequenceChildrenList.add(rewrittenProducer);
newCteProducerList.add(rewrittenProducer);
}
else {
this.planChanged = true;
removedIndexes.add(i);
}
}
PlanNode rewrittenPrimarySource = context.rewrite(node.getPrimarySource());
if (isEmptyNode(rewrittenPrimarySource) || newSequenceChildrenList.isEmpty()) {
if (isEmptyNode(rewrittenPrimarySource) || newCteProducerList.isEmpty()) {
return rewrittenPrimarySource;
}
// Nothing has been flagged as changed, so the original node is returned as is.
if (!this.planChanged) {
return node;
}
// The loop above collected producers from last to first; reverse to restore the original list order
Collections.reverse(newSequenceChildrenList);
// Add the primary source at the end of the list
newSequenceChildrenList.add(rewrittenPrimarySource);
return node.replaceChildren(newSequenceChildrenList);
Collections.reverse(newCteProducerList);
// Build a new SequenceNode from the surviving producers, passing the dependency graph
// obtained from removeCteProducersFromCteDependencyGraph for the removed indexes.
return new SequenceNode(node.getSourceLocation(),
idAllocator.getNextId(),
ImmutableList.copyOf(newCteProducerList),
rewrittenPrimarySource,
node.removeCteProducersFromCteDependencyGraph(removedIndexes));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public PlanNode visitSequence(SequenceNode node, RewriteContext<Void> context)
SimplePlanRewriter.rewriteWith(new Rewriter(types, functionAndTypeManager, warningCollector), c))
.collect(Collectors.toList());
PlanNode primarySource = context.rewrite(node.getPrimarySource());
return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource);
return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource, node.getCteDependencyGraph());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.graph.GraphBuilder;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -531,7 +532,7 @@ public void testSequence()
Optional.empty(),
new PlanNodeId("sequence"),
ImmutableList.of(cteProducerNode1, cteProducerNode2),
joinNode);
joinNode, GraphBuilder.directed().build());

// Define cost of sequence children
Map<String, PlanCostEstimate> costs = ImmutableMap.of(
Expand Down
Loading

0 comments on commit 97fb98e

Please sign in to comment.