Skip to content

Commit

Permalink
Improve cte scheduling by scheduling subgraphs independently
Browse files Browse the repository at this point in the history
This improves latency of cte materialized queries by scheduling subgraphs independently
  • Loading branch information
jaystarshot committed Apr 5, 2024
1 parent b13af96 commit 75b3799
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ protected QueryRunner createQueryRunner()
Optional.empty());
}

/**
 * Runs a query with three CTEs, one of which ({@code t}) reads from a subquery filtered by
 * {@code where false}, and checks that the CTE-materialized session returns the same rows
 * as the regular session.
 */
@Test
public void testCteExecutionWhereOneCteRemovedBySimplifyEmptyInputRule()
{
    QueryRunner runner = getQueryRunner();
    String query = "WITH t as(select orderkey, count(*) as count from (select orderkey from orders where false) group by orderkey)," +
            "t1 as (SELECT * FROM orders)," +
            " b AS ((SELECT orderkey FROM t) UNION (SELECT orderkey FROM t1)) " +
            "SELECT * FROM b";
    // Materialized execution is evaluated first, then the plain session used as the baseline
    compareResults(
            runner.execute(getMaterializedSession(), query),
            runner.execute(getSession(), query));
}

@Test
public void testSimplePersistentCte()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,25 +204,33 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext<FragmentPr
@Override
public PlanNode visitSequence(SequenceNode node, RewriteContext<FragmentProperties> context)
{
// Since this is topologically sorted by the LogicalCtePlanner, need to make sure that execution order follows
// Can be optimized further to avoid non dependents from getting blocked
int cteProducerCount = node.getCteProducers().size();
checkArgument(cteProducerCount >= 1, "Sequence Node has 0 CTE producers");
PlanNode source = node.getCteProducers().get(cteProducerCount - 1);
FragmentProperties childProperties = new FragmentProperties(new PartitioningScheme(
Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
source.getOutputVariables()));
SubPlan lastSubPlan = buildSubPlan(source, childProperties, context);

for (int sourceIndex = cteProducerCount - 2; sourceIndex >= 0; sourceIndex--) {
source = node.getCteProducers().get(sourceIndex);
childProperties = new FragmentProperties(new PartitioningScheme(
// To ensure that the execution order is maintained, we use an independent dependency graph.
// This approach creates subgraphs sequentially, enhancing control over the execution flow. However, there are optimization opportunities:
// 1. Can consider blocking only the CTEConsumer stages that are in a reading state.
// This approach sounds good on paper but may not be ideal, as it can block the entire query, leading to resource wastage since no progress can be made until the writing operations are complete.
// 2. TODO: Another improvement would be to schedule the execution of subgraphs based on their order in the overall execution plan instead of the topological sorting done here,
// but that requires changes to the plan section framework so that it can handle the same child planSection.
List<List<PlanNode>> independentCteProducerSubgraphs = node.getIndependentCteProducers();
for (List<PlanNode> cteProducerSubgraph : independentCteProducerSubgraphs) {
int cteProducerCount = cteProducerSubgraph.size();
checkArgument(cteProducerCount >= 1, "CteProducer subgraph has 0 CTE producers");
PlanNode source = cteProducerSubgraph.get(cteProducerCount - 1);
FragmentProperties childProperties = new FragmentProperties(new PartitioningScheme(
Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
source.getOutputVariables()));
childProperties.addChildren(ImmutableList.of(lastSubPlan));
lastSubPlan = buildSubPlan(source, childProperties, context);
SubPlan lastSubPlan = buildSubPlan(source, childProperties, context);
for (int sourceIndex = cteProducerCount - 2; sourceIndex >= 0; sourceIndex--) {
source = cteProducerSubgraph.get(sourceIndex);
childProperties = new FragmentProperties(new PartitioningScheme(
Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
source.getOutputVariables()));
childProperties.addChildren(ImmutableList.of(lastSubPlan));
lastSubPlan = buildSubPlan(source, childProperties, context);
}
// This makes sure that the sectionedPlans generated in com.facebook.presto.execution.scheduler.StreamingPlanSection
// are independent and thus could be scheduled concurrently
context.get().addChildren(ImmutableList.of(lastSubPlan));
}
context.get().addChildren(ImmutableList.of(lastSubPlan));
return node.getPrimarySource().accept(this, context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.graph.Graph;
import com.google.common.graph.GraphBuilder;
import com.google.common.graph.MutableGraph;
import com.google.common.graph.Traverser;
Expand Down Expand Up @@ -135,8 +136,11 @@ public PlanNode transformPersistentCtes(Session session, PlanNode root)
return transformedCte;
}
isPlanRewritten = true;
SequenceNode sequenceNode = new SequenceNode(root.getSourceLocation(), planNodeIdAllocator.getNextId(), topologicalOrderedList,
transformedCte.getSources().get(0));
SequenceNode sequenceNode = new SequenceNode(root.getSourceLocation(),
planNodeIdAllocator.getNextId(),
topologicalOrderedList,
transformedCte.getSources().get(0),
context.createIndexedGraphFromTopologicallySortedCteProducers(topologicalOrderedList));
return root.replaceChildren(Arrays.asList(sequenceNode));
}

Expand Down Expand Up @@ -271,23 +275,25 @@ public PlanNode visitApply(ApplyNode node, RewriteContext<LogicalCteOptimizerCon
node.getCorrelation(),
node.getOriginSubqueryError(),
node.getMayParticipateInAntiJoin());
}}
}
}

public static class LogicalCteOptimizerContext
{
public Map<String, CteProducerNode> cteProducerMap;

// a -> b indicates that b needs to be processed before a
MutableGraph<String> graph;
public Stack<String> activeCteStack;
// a -> b indicates that a needs to be processed before b
private MutableGraph<String> cteDependencyGraph;

private Stack<String> activeCteStack;

public Set<String> complexCtes;
private Set<String> complexCtes;

public LogicalCteOptimizerContext()
{
cteProducerMap = new HashMap<>();
// The cte graph will never have cycles because sql won't allow it
graph = GraphBuilder.directed().build();
cteDependencyGraph = GraphBuilder.directed().build();
activeCteStack = new Stack<>();
complexCtes = new HashSet<>();
}
Expand Down Expand Up @@ -319,9 +325,10 @@ public Optional<String> peekActiveCte()

public void addDependency(String currentCte)
{
graph.addNode(currentCte);
cteDependencyGraph.addNode(currentCte);
Optional<String> parentCte = peekActiveCte();
parentCte.ifPresent(s -> graph.putEdge(currentCte, s));
// (current -> parentCte) this indicates that currentCte must be processed first
parentCte.ifPresent(s -> cteDependencyGraph.putEdge(currentCte, s));
}

public void addComplexCte(String cteId)
Expand All @@ -342,9 +349,29 @@ public boolean isComplexCte(String cteId)
public List<PlanNode> getTopologicalOrdering()
{
ImmutableList.Builder<PlanNode> topSortedCteProducerListBuilder = ImmutableList.builder();
Traverser.forGraph(graph).depthFirstPostOrder(graph.nodes())
Traverser.forGraph(cteDependencyGraph).depthFirstPostOrder(cteDependencyGraph.nodes())
.forEach(cteName -> topSortedCteProducerListBuilder.add(cteProducerMap.get(cteName)));
return topSortedCteProducerListBuilder.build();
}

/**
 * Translates the CTE-id dependency graph into a graph over list positions.
 *
 * Every producer in the supplied list becomes a node identified by its 0-based position,
 * and every edge of {@code cteDependencyGraph} is re-created between the positions of its endpoints.
 *
 * @param topologicalSortedCteProducerList producers (each cast to {@link CteProducerNode}) whose positions become node ids
 * @return a directed graph of producer positions mirroring the edges of {@code cteDependencyGraph}
 */
public Graph<Integer> createIndexedGraphFromTopologicallySortedCteProducers(List<PlanNode> topologicalSortedCteProducerList)
{
    MutableGraph<Integer> positionGraph = GraphBuilder
            .directed()
            .expectedNodeCount(topologicalSortedCteProducerList.size())
            .build();

    // Register each producer's position as a node and remember which CTE id it belongs to
    Map<String, Integer> positionByCteId = new HashMap<>();
    int position = 0;
    for (PlanNode producer : topologicalSortedCteProducerList) {
        positionByCteId.put(((CteProducerNode) producer).getCteId(), position);
        positionGraph.addNode(position);
        position++;
    }

    // Copy each dependency edge, replacing CTE ids with list positions
    for (String cteId : cteDependencyGraph.nodes()) {
        for (String dependentCteId : cteDependencyGraph.successors(cteId)) {
            positionGraph.putEdge(positionByCteId.get(cteId), positionByCteId.get(dependentCteId));
        }
    }
    return positionGraph;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ public PlanNode visitSequence(SequenceNode node, RewriteContext<Set<VariableRefe
.map(leftSource -> context.rewrite(leftSource, leftInputs)).collect(toImmutableList());
Set<VariableReferenceExpression> rightInputs = ImmutableSet.copyOf(node.getPrimarySource().getOutputVariables());
PlanNode primarySource = context.rewrite(node.getPrimarySource(), rightInputs);
return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource);
return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource, node.getCteDependencyGraph());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isSimplifyPlanWithEmptyInputEnabled;
Expand Down Expand Up @@ -209,23 +211,33 @@ else if (isEmptyNode(rewrittenRight)) {
public PlanNode visitSequence(SequenceNode node, RewriteContext<Void> context)
{
List<PlanNode> cteProducers = node.getCteProducers();
List<PlanNode> newSequenceChildrenList = new ArrayList<>();
List<PlanNode> newCteProducerList = new ArrayList<>();
// Visit in the order of execution
Set<Integer> removedIndexes = new HashSet<>();
for (int i = cteProducers.size() - 1; i >= 0; i--) {
PlanNode rewrittenProducer = context.rewrite(cteProducers.get(i));
if (!isEmptyNode(rewrittenProducer)) {
newSequenceChildrenList.add(rewrittenProducer);
newCteProducerList.add(rewrittenProducer);
}
else {
this.planChanged = true;
removedIndexes.add(i);
}
}
PlanNode rewrittenPrimarySource = context.rewrite(node.getPrimarySource());
if (isEmptyNode(rewrittenPrimarySource) || newSequenceChildrenList.isEmpty()) {
if (isEmptyNode(rewrittenPrimarySource) || newCteProducerList.isEmpty()) {
return rewrittenPrimarySource;
}
if (!this.planChanged) {
return node;
}
// Reverse order for execution
Collections.reverse(newSequenceChildrenList);
// Add the primary source at the end of the list
newSequenceChildrenList.add(rewrittenPrimarySource);
return node.replaceChildren(newSequenceChildrenList);
Collections.reverse(newCteProducerList);
return new SequenceNode(node.getSourceLocation(),
idAllocator.getNextId(),
ImmutableList.copyOf(newCteProducerList),
rewrittenPrimarySource,
node.removeCteProducersFromCteDependencyGraph(removedIndexes));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public PlanNode visitSequence(SequenceNode node, RewriteContext<Void> context)
SimplePlanRewriter.rewriteWith(new Rewriter(types, functionAndTypeManager, warningCollector), c))
.collect(Collectors.toList());
PlanNode primarySource = context.rewrite(node.getPrimarySource());
return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource);
return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource, node.getCteDependencyGraph());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,23 @@
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.graph.Graph;
import com.google.common.graph.GraphBuilder;
import com.google.common.graph.ImmutableGraph;
import com.google.common.graph.MutableGraph;
import com.google.common.graph.Traverser;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;

public class SequenceNode
extends InternalPlanNode
Expand All @@ -31,24 +44,32 @@ public class SequenceNode
private final List<PlanNode> cteProducers;
private final PlanNode primarySource;

// Directed graph of cte Producer Indexes (0 indexed)
// a -> b indicates that producer at index a needs to be processed before producer at index b
private final Graph<Integer> cteDependencyGraph;

@JsonCreator
public SequenceNode(Optional<SourceLocation> sourceLocation,
@JsonProperty("id") PlanNodeId planNodeId,
@JsonProperty("cteProducers") List<PlanNode> left,
@JsonProperty("primarySource") PlanNode primarySource)
@JsonProperty("cteProducers") List<PlanNode> cteProducerList,
@JsonProperty("primarySource") PlanNode primarySource,
Graph<Integer> cteDependencyGraph)
{
this(sourceLocation, planNodeId, Optional.empty(), left, primarySource);
this(sourceLocation, planNodeId, Optional.empty(), cteProducerList, primarySource, cteDependencyGraph);
}

public SequenceNode(Optional<SourceLocation> sourceLocation,
PlanNodeId planNodeId,
Optional<PlanNode> statsEquivalentPlanNode,
List<PlanNode> leftList,
PlanNode primarySource)
List<PlanNode> cteProducerList,
PlanNode primarySource,
Graph<Integer> cteDependencyGraph)
{
super(sourceLocation, planNodeId, statsEquivalentPlanNode);
this.cteProducers = leftList;
this.cteProducers = ImmutableList.copyOf(cteProducerList);
this.primarySource = primarySource;
checkArgument(cteDependencyGraph.isDirected(), "Sequence Node expects a directed graph");
this.cteDependencyGraph = ImmutableGraph.copyOf(cteDependencyGraph);
}

@JsonProperty
Expand Down Expand Up @@ -80,14 +101,83 @@ public List<VariableReferenceExpression> getOutputVariables()
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
checkArgument(newChildren.size() == cteProducers.size() + 1, "expected newChildren to contain same number of nodes as current." +
" If the child count please update the dependency graph");
return new SequenceNode(newChildren.get(0).getSourceLocation(), getId(), getStatsEquivalentPlanNode(),
newChildren.subList(0, newChildren.size() - 1), newChildren.get(newChildren.size() - 1));
newChildren.subList(0, newChildren.size() - 1), newChildren.get(newChildren.size() - 1), cteDependencyGraph);
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new SequenceNode(getSourceLocation(), getId(), statsEquivalentPlanNode, cteProducers, this.getPrimarySource());
return new SequenceNode(getSourceLocation(), getId(), statsEquivalentPlanNode, cteProducers, this.getPrimarySource(), cteDependencyGraph);
}

/**
 * @return the directed graph over CTE producer indexes held by this node
 */
public Graph<Integer> getCteDependencyGraph()
{
    return this.cteDependencyGraph;
}

/**
 * Builds the dependency graph that remains after dropping the given CTE producers.
 *
 * Surviving producers are renumbered so that their indexes stay contiguous: each surviving
 * index is shifted down by the number of removed indexes that precede it. Edges touching a
 * removed producer are discarded; all other edges are re-created between the renumbered endpoints.
 *
 * @param indexesToRemove indexes (into the current producer list) of the producers being removed
 * @return an immutable graph containing one node per surviving producer and the surviving edges
 */
public Graph<Integer> removeCteProducersFromCteDependencyGraph(Set<Integer> indexesToRemove)
{
    Graph<Integer> originalGraph = getCteDependencyGraph();
    // GraphBuilder.from only copies the graph's properties (e.g. directedness); the new graph starts with no nodes or edges
    MutableGraph<Integer> newCteDependencyGraph = GraphBuilder.from(originalGraph).build();
    Map<Integer, Integer> indexMapping = new HashMap<>();
    int removed = 0;
    for (int prevIndex = 0; prevIndex < cteProducers.size(); prevIndex++) {
        if (indexesToRemove.contains(prevIndex)) {
            removed++;
        }
        else {
            int newIndex = prevIndex - removed;
            indexMapping.put(prevIndex, newIndex);
            // Add the node explicitly: a surviving producer without any surviving edge would
            // otherwise be absent from the graph, since putEdge only adds endpoints of edges
            newCteDependencyGraph.addNode(newIndex);
        }
    }
    for (int oldIndex : originalGraph.nodes()) {
        if (indexesToRemove.contains(oldIndex)) {
            continue;
        }
        Integer newIndex = indexMapping.get(oldIndex);
        for (Integer successor : originalGraph.successors(oldIndex)) {
            if (!indexesToRemove.contains(successor)) {
                newCteDependencyGraph.putEdge(newIndex, indexMapping.get(successor));
            }
        }
    }
    return ImmutableGraph.copyOf(newCteDependencyGraph);
}

/**
 * Splits the CTE producers into groups that share no dependency edges with one another.
 *
 * Groups are the connected components of the dependency graph when edge direction is ignored.
 * Within a group, producers are listed in depth-first post-order of the directed dependency graph,
 * so a producer appears after all of its successors in that graph.
 *
 * @return one list of producers per connected component of the dependency graph
 */
public List<List<PlanNode>> getIndependentCteProducers()
{
    // Undirected mirror of the dependency graph, used only to discover connected components
    MutableGraph<Integer> undirectedDependencyGraph = GraphBuilder.undirected().allowsSelfLoops(false).build();
    cteDependencyGraph.nodes().forEach(undirectedDependencyGraph::addNode);
    cteDependencyGraph.edges().forEach(edge -> undirectedDependencyGraph.putEdge(edge.nodeU(), edge.nodeV()));

    // Both graphs are fixed from here on, so the traversers can be shared by every component
    Traverser<Integer> componentTraverser = Traverser.forGraph(undirectedDependencyGraph);
    Traverser<Integer> dependencyTraverser = Traverser.forGraph(cteDependencyGraph);

    Set<Integer> visitedCteSet = new HashSet<>();
    ImmutableList.Builder<List<PlanNode>> independentCteProducerList = ImmutableList.builder();
    for (Integer cteIndex : cteDependencyGraph.nodes()) {
        if (visitedCteSet.contains(cteIndex)) {
            // Already emitted as part of an earlier component
            continue;
        }
        // Collect every node reachable from cteIndex when edge direction is ignored
        Set<Integer> componentNodes = new HashSet<>();
        componentTraverser.breadthFirst(cteIndex).forEach(componentNode -> {
            if (visitedCteSet.add(componentNode)) {
                componentNodes.add(componentNode);
            }
        });

        // Order the component by post-order over the directed graph: successors come before their predecessors
        List<Integer> topSortedCteProducerList = new ArrayList<>();
        dependencyTraverser.depthFirstPostOrder(componentNodes).forEach(topSortedCteProducerList::add);
        // The component always contains at least cteIndex, so the list is never empty
        independentCteProducerList.add(topSortedCteProducerList.stream().map(cteProducers::get).collect(Collectors.toList()));
    }
    return independentCteProducerList.build();
}

@Override
Expand Down
Loading

0 comments on commit 75b3799

Please sign in to comment.