Skip to content

Commit

Permalink
Support CTE nodes in cost calculator
Browse files Browse the repository at this point in the history
Summary:
Supporting CTE logical nodes in the cost calculator is crucial for calculating cost; this PR adds that support.
  • Loading branch information
jaystarshot committed Feb 14, 2024
1 parent 6d749f9 commit 946d4cd
Show file tree
Hide file tree
Showing 17 changed files with 639 additions and 260 deletions.
513 changes: 265 additions & 248 deletions presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ public final class SystemSessionProperties
public static final String CTE_MATERIALIZATION_STRATEGY = "cte_materialization_strategy";
public static final String CTE_FILTER_AND_PROJECTION_PUSHDOWN_ENABLED = "cte_filter_and_projection_pushdown_enabled";
public static final String DEFAULT_JOIN_SELECTIVITY_COEFFICIENT = "default_join_selectivity_coefficient";
public static final String DEFAULT_WRITER_REPLICATION_COEFFICIENT = "default_writer_replication_coefficient";
public static final String PUSH_LIMIT_THROUGH_OUTER_JOIN = "push_limit_through_outer_join";
public static final String OPTIMIZE_CONSTANT_GROUPING_KEYS = "optimize_constant_grouping_keys";
public static final String MAX_CONCURRENT_MATERIALIZATIONS = "max_concurrent_materializations";
Expand Down Expand Up @@ -1091,6 +1092,11 @@ public SystemSessionProperties(
false,
value -> validateDoubleValueWithinSelectivityRange(value, DEFAULT_JOIN_SELECTIVITY_COEFFICIENT),
object -> object),
doubleProperty(
DEFAULT_WRITER_REPLICATION_COEFFICIENT,
"Replication coefficient for costing write operations",
featuresConfig.getDefaultWriterReplicationCoefficient(),
false),
booleanProperty(
PUSH_LIMIT_THROUGH_OUTER_JOIN,
"push limits to the outer side of an outer join",
Expand Down Expand Up @@ -2429,6 +2435,11 @@ public static boolean getCteFilterAndProjectionPushdownEnabled(Session session)
return session.getSystemProperty(CTE_FILTER_AND_PROJECTION_PUSHDOWN_ENABLED, Boolean.class);
}

/**
 * Reads the {@code DEFAULT_WRITER_REPLICATION_COEFFICIENT} system property from the given session as a double.
 * Note that the method name mentions the CTE producer, while the property it reads is the
 * writer replication coefficient.
 */
public static double getCteProducerReplicationCoefficient(Session session)
{
return session.getSystemProperty(DEFAULT_WRITER_REPLICATION_COEFFICIENT, Double.class);
}

public static int getFilterAndProjectMinOutputPageRowCount(Session session)
{
return session.getSystemProperty(FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_ROW_COUNT, Integer.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@

import com.facebook.presto.Session;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinDistributionType;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.OutputNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SequenceNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.ValuesNode;
Expand All @@ -47,6 +50,7 @@
import java.util.Optional;
import java.util.stream.Stream;

import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateCteProducerCost;
import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateJoinInputCost;
import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateLocalRepartitionCost;
import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateRemoteGatherCost;
Expand Down Expand Up @@ -78,19 +82,21 @@ public CostCalculatorUsingExchanges(TaskCountEstimator taskCountEstimator)
/**
 * Calculates the cost of {@code node} by constructing a {@code CostEstimator} visitor for this call
 * and having the node accept it with a {@code null} context.
 */
@Override
public PlanCostEstimate calculateCost(PlanNode node, StatsProvider stats, CostProvider sourcesCosts, Session session)
{
CostEstimator costEstimator = new CostEstimator(stats, sourcesCosts, taskCountEstimator);
CostEstimator costEstimator = new CostEstimator(session, stats, sourcesCosts, taskCountEstimator);
return node.accept(costEstimator, null);
}

private static class CostEstimator
extends InternalPlanVisitor<PlanCostEstimate, Void>
{
private final Session session;
private final StatsProvider stats;
private final CostProvider sourcesCosts;
private final TaskCountEstimator taskCountEstimator;

CostEstimator(StatsProvider stats, CostProvider sourcesCosts, TaskCountEstimator taskCountEstimator)
CostEstimator(Session session, StatsProvider stats, CostProvider sourcesCosts, TaskCountEstimator taskCountEstimator)
{
this.session = requireNonNull(session, "session is null");
this.stats = requireNonNull(stats, "stats is null");
this.sourcesCosts = requireNonNull(sourcesCosts, "sourcesCosts is null");
this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null");
Expand Down Expand Up @@ -156,6 +162,25 @@ public PlanCostEstimate visitFilter(FilterNode node, Void context)
return costForStreaming(node, localCost);
}

/**
 * Costs a CTE producer. The local cost is obtained from {@code calculateCteProducerCost}
 * for the producer's source and is then handed to {@code costForStreaming}.
 */
@Override
public PlanCostEstimate visitCteProducer(CteProducerNode node, Void context)
{
    return costForStreaming(node, calculateCteProducerCost(session, stats, node.getSource()));
}

/**
 * Costs a CTE consumer by visiting the node's original source with this same visitor
 * and returning that result unchanged.
 */
@Override
public PlanCostEstimate visitCteConsumer(CteConsumerNode node, Void context)
{
    PlanNode originalSource = node.getOriginalSource();
    return originalSource.accept(this, context);
}

/**
 * Costs a sequence node with a zero local cost, passed through {@code costForStreaming}.
 */
@Override
public PlanCostEstimate visitSequence(SequenceNode node, Void context)
{
    LocalCostEstimate noLocalCost = LocalCostEstimate.zero();
    return costForStreaming(node, noLocalCost);
}

@Override
public PlanCostEstimate visitProject(ProjectNode node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.JoinDistributionType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.SequenceNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType;
Expand All @@ -34,7 +35,9 @@
import java.util.Objects;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.getCteProducerReplicationCoefficient;
import static com.facebook.presto.cost.LocalCostEstimate.addPartialComponents;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -166,6 +169,13 @@ public LocalCostEstimate visitUnion(UnionNode node, Void context)
return calculateRemoteGatherCost(inputSizeInBytes);
}

/**
 * Visits every source of the sequence node with this visitor and combines the
 * resulting local cost estimates with {@code addPartialComponents}.
 */
@Override
public LocalCostEstimate visitSequence(SequenceNode node, Void context)
{
    // NOTE(review): each source is dispatched to this same visitor; confirm that every
    // source node type that can appear here (e.g. group references) is handled by it.
    return addPartialComponents(
            node.getSources().stream()
                    .map(source -> source.accept(this, context))
                    .collect(toImmutableList()));
}

@Override
public LocalCostEstimate visitIntersect(IntersectNode node, Void context)
{
Expand All @@ -190,6 +200,13 @@ public static LocalCostEstimate calculateRemoteRepartitionCost(double inputSizeI
return LocalCostEstimate.of(inputSizeInBytes, 0, inputSizeInBytes);
}

/**
 * Computes the local cost of a CTE producer from the output size of its source.
 * The source's output size in bytes is multiplied by the replication coefficient read from
 * the session; the product is used for the first and third arguments of
 * {@code LocalCostEstimate.of}, and the second argument is zero.
 *
 * @param session session the replication coefficient is read from
 * @param statsProvider provider of the stats for {@code source}
 * @param source the plan node feeding the CTE producer
 */
public static LocalCostEstimate calculateCteProducerCost(Session session, StatsProvider statsProvider, PlanNode source)
{
    double sourceOutputBytes = statsProvider.getStats(source).getOutputSizeInBytes(source);
    double replicationCoefficient = getCteProducerReplicationCoefficient(session);
    double replicatedBytes = replicationCoefficient * sourceOutputBytes;
    return LocalCostEstimate.of(replicatedBytes, 0, replicatedBytes);
}

public static LocalCostEstimate calculateLocalRepartitionCost(double inputSizeInBytes)
{
return LocalCostEstimate.ofCpu(inputSizeInBytes);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;

import java.util.Optional;

import static com.facebook.presto.sql.planner.plan.Patterns.cteConsumer;

/**
 * Stats rule for {@link CteConsumerNode}: the consumer reports the stats of its original source.
 */
public class CteConsumerStatsRule
        implements ComposableStatsCalculator.Rule<CteConsumerNode>
{
    private static final Pattern<CteConsumerNode> PATTERN = cteConsumer();

    /**
     * Returns the pattern this rule is registered for.
     */
    @Override
    public Pattern<CteConsumerNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Looks up the stats of the consumer's original source and returns them wrapped in an {@link Optional}.
     */
    @Override
    public Optional<PlanNodeStatsEstimate> calculate(CteConsumerNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
    {
        PlanNodeStatsEstimate originalSourceStats = sourceStats.getStats(node.getOriginalSource());
        return Optional.of(originalSourceStats);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;

import java.util.Optional;

import static com.facebook.presto.sql.planner.plan.Patterns.cteProducer;

/**
 * Stats rule for {@link CteProducerNode}: the producer reports the stats of its source.
 */
public class CteProducerStatsRule
        implements ComposableStatsCalculator.Rule<CteProducerNode>
{
    private static final Pattern<CteProducerNode> PATTERN = cteProducer();

    /**
     * Returns the pattern this rule is registered for.
     */
    @Override
    public Pattern<CteProducerNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Looks up the stats of the producer's source and returns them wrapped in an {@link Optional}.
     */
    @Override
    public Optional<PlanNodeStatsEstimate> calculate(CteProducerNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
    {
        PlanNodeStatsEstimate producerSourceStats = sourceStats.getStats(node.getSource());
        return Optional.of(producerSourceStats);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.spi.plan.PlanNode;

import java.util.List;
import java.util.stream.Stream;

import static com.google.common.base.MoreObjects.toStringHelper;
Expand Down Expand Up @@ -106,4 +107,13 @@ public static LocalCostEstimate addPartialComponents(LocalCostEstimate one, Loca
a.maxMemory + b.maxMemory,
a.networkCost + b.networkCost));
}

/**
 * Sums the cpu, memory and network components of every estimate in {@code planList},
 * starting from {@code zero()} and adding the elements in list order.
 *
 * @param planList the estimates to add together
 * @return the component-wise sum; the {@code zero()} estimate when the list is empty
 */
public static LocalCostEstimate addPartialComponents(List<LocalCostEstimate> planList)
{
    LocalCostEstimate total = zero();
    for (LocalCostEstimate estimate : planList) {
        total = new LocalCostEstimate(
                total.cpuCost + estimate.cpuCost,
                total.maxMemory + estimate.maxMemory,
                total.networkCost + estimate.networkCost);
    }
    return total;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.SequenceNode;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;

import java.util.Optional;

import static com.facebook.presto.sql.planner.plan.Patterns.sequenceNode;

/**
 * Stats rule for {@link SequenceNode}: the sequence reports the stats of its primary source.
 */
public class SequenceStatsRule
        implements ComposableStatsCalculator.Rule<SequenceNode>
{
    private static final Pattern<SequenceNode> PATTERN = sequenceNode();

    /**
     * Returns the pattern this rule is registered for.
     */
    @Override
    public Pattern<SequenceNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Looks up the stats of the sequence's primary source and returns them wrapped in an {@link Optional}.
     */
    @Override
    public Optional<PlanNodeStatsEstimate> calculate(SequenceNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
    {
        PlanNodeStatsEstimate primarySourceStats = sourceStats.getStats(node.getPrimarySource());
        return Optional.of(primarySourceStats);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ public static ComposableStatsCalculator createComposableStatsCalculator(
rules.add(new SampleStatsRule(normalizer));
rules.add(new IntersectStatsRule(normalizer));
rules.add(new RemoteSourceStatsRule(fragmentStatsProvider, normalizer));
rules.add(new SequenceStatsRule());
rules.add(new CteProducerStatsRule());
rules.add(new CteConsumerStatsRule());

return new ComposableStatsCalculator(rules.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ public class FeaturesConfig
// Give a default 10% selectivity coefficient factor to avoid hitting unknown stats in join stats estimates
// which could result in syntactic join order. Set it to 0 to disable this feature
private double defaultJoinSelectivityCoefficient;
private double defaultWriterReplicationCoefficient = 3;
private boolean pushAggregationThroughJoin = true;
private double memoryRevokingTarget = 0.5;
private double memoryRevokingThreshold = 0.9;
Expand Down Expand Up @@ -1500,6 +1501,19 @@ public double getDefaultJoinSelectivityCoefficient()
return defaultJoinSelectivityCoefficient;
}

/**
 * Sets the replication coefficient used for costing write operations.
 *
 * @param defaultWriterReplicationCoefficient the coefficient value to store
 * @return this config instance, so calls can be chained
 */
@Config("optimizer.default-writer-replication-coefficient")
@ConfigDescription("Replication coefficient for costing write operations")
public FeaturesConfig setDefaultWriterReplicationCoefficient(double defaultWriterReplicationCoefficient)
{
    // The parameter was previously named after the join selectivity coefficient (copy-paste);
    // it is named after the field it actually assigns.
    this.defaultWriterReplicationCoefficient = defaultWriterReplicationCoefficient;
    return this;
}

/**
 * Returns the replication coefficient used for costing write operations.
 */
public double getDefaultWriterReplicationCoefficient()
{
return defaultWriterReplicationCoefficient;
}

public DataSize getTopNOperatorUnspillMemoryLimit()
{
return topNOperatorUnspillMemoryLimit;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,7 @@ public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext<CteContext
newConsumerColumns.add(pair.getValue());
}
});

return new CteConsumerNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), newConsumerColumns, node.getCteName());
return new CteConsumerNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), newConsumerColumns, node.getCteName(), node.getOriginalSource());
}

public boolean isPlanRewritten()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public PlanNode visitCteReference(CteReferenceNode node, RewriteContext<CteTrans
node.getCteName(),
variableAllocator.newVariable("rows", BIGINT), node.getOutputVariables());
context.get().addProducer(node.getCteName(), cteProducerSource);
return new CteConsumerNode(node.getSourceLocation(), idAllocator.getNextId(), actualSource.getOutputVariables(), node.getCteName());
return new CteConsumerNode(node.getSourceLocation(), idAllocator.getNextId(), Optional.of(actualSource), actualSource.getOutputVariables(), node.getCteName(), actualSource);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.matching.Property;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.JoinType;
Expand All @@ -24,6 +26,7 @@
import com.facebook.presto.spi.plan.OutputNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SequenceNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
Expand Down Expand Up @@ -122,6 +125,21 @@ public static Pattern<OutputNode> output()
return typeOf(OutputNode.class);
}

/**
 * Returns the pattern produced by {@code typeOf} for {@link CteProducerNode}.
 */
public static Pattern<CteProducerNode> cteProducer()
{
return typeOf(CteProducerNode.class);
}

/**
 * Returns the pattern produced by {@code typeOf} for {@link CteConsumerNode}.
 */
public static Pattern<CteConsumerNode> cteConsumer()
{
return typeOf(CteConsumerNode.class);
}

/**
 * Returns the pattern produced by {@code typeOf} for {@link SequenceNode}.
 */
public static Pattern<SequenceNode> sequenceNode()
{
return typeOf(SequenceNode.class);
}

public static Pattern<ProjectNode> project()
{
return typeOf(ProjectNode.class);
Expand Down
Loading

0 comments on commit 946d4cd

Please sign in to comment.