Skip to content

Commit

Permalink
Add segmented aggregation support in query plan
Browse files Browse the repository at this point in the history
Enable segmented aggregation if a prefix of the sorted-by columns is a
subset of the group-by columns
  • Loading branch information
kewang1024 authored and rschlussel committed May 2, 2022
1 parent 8c850f7 commit 84436f5
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.hive;

import com.facebook.presto.Session;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.SEGMENTED_AGGREGATION_ENABLED;
import static com.facebook.presto.hive.HiveQueryRunner.HIVE_CATALOG;
import static com.facebook.presto.hive.HiveSessionProperties.ORDER_BASED_EXECUTION_ENABLED;
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.airlift.tpch.TpchTable.CUSTOMER;
import static io.airlift.tpch.TpchTable.LINE_ITEM;
import static io.airlift.tpch.TpchTable.NATION;
import static io.airlift.tpch.TpchTable.ORDERS;

/**
 * Plan-shape test for segmented aggregation against Hive tables.
 *
 * <p>The test creates a bucketed and sorted table, plans a GROUP BY whose keys are a
 * superset of the table's sorted-by columns, and asserts on the resulting aggregation node.
 */
public class TestSegmentedAggregation
        extends AbstractTestQueryFramework
{
    /**
     * Builds a Hive query runner preloaded with the ORDERS, LINE_ITEM, CUSTOMER and NATION
     * TPCH tables, with subfield pushdown switched on.
     */
    @Override
    protected QueryRunner createQueryRunner()
            throws Exception
    {
        return HiveQueryRunner.createQueryRunner(
                ImmutableList.of(ORDERS, LINE_ITEM, CUSTOMER, NATION),
                ImmutableMap.of("experimental.pushdown-subfields-enabled", "true"),
                Optional.empty());
    }

    /**
     * Table is sorted by (custkey, name); the query groups by (custkey, name, nationkey).
     * The expected plan is a SINGLE aggregation whose list argument is (custkey, name),
     * i.e. the sorted-by columns that also appear in the grouping keys.
     */
    @Test
    public void testAndSortedByKeysArePrefixOfGroupbyKeys()
    {
        QueryRunner runner = getQueryRunner();

        String createTableSql = "CREATE TABLE test_segmented_streaming_customer WITH ( \n" +
                " bucket_count = 4, bucketed_by = ARRAY['custkey', 'name'], \n" +
                " sorted_by = ARRAY['custkey', 'name'], partitioned_by=array['ds'], \n" +
                " format = 'DWRF' ) AS \n" +
                "SELECT *, '2021-07-11' as ds FROM customer LIMIT 1000\n";
        String groupBySql = "SELECT custkey, name, nationkey, COUNT(*) FROM test_segmented_streaming_customer \n" +
                "WHERE ds = '2021-07-11' GROUP BY 1, 2, 3";

        try {
            runner.execute(createTableSql);

            Session session = orderBasedExecutionEnabled();
            assertPlan(
                    session,
                    groupBySql,
                    anyTree(aggregation(
                            singleGroupingSet("custkey", "name", "nationkey"),
                            ImmutableMap.of(Optional.empty(), functionCall("count", ImmutableList.of())),
                            // segmented streaming: only the sorted-by columns are listed here
                            ImmutableList.of("custkey", "name"),
                            ImmutableMap.of(),
                            Optional.empty(),
                            SINGLE,
                            tableScan(
                                    "test_segmented_streaming_customer",
                                    ImmutableMap.of("custkey", "custkey", "name", "name", "nationkey", "nationkey")))));
        }
        finally {
            // Always clean up so a failed assertion does not leak the table into other tests.
            runner.execute("DROP TABLE IF EXISTS test_segmented_streaming_customer");
        }
    }

    // TODO: add a test for the case where the group-by keys and a prefix of the sorted-by keys share the same elements

    /**
     * Derives a session from the runner's default session with the Hive order-based
     * execution catalog property and the segmented aggregation system property both set to "true".
     */
    private Session orderBasedExecutionEnabled()
    {
        Session defaults = getQueryRunner().getDefaultSession();
        return Session.builder(defaults)
                .setCatalogSessionProperty(HIVE_CATALOG, ORDER_BASED_EXECUTION_ENABLED, "true")
                .setSystemProperty(SEGMENTED_AGGREGATION_ENABLED, "true")
                .build();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -246,38 +246,6 @@ public void testGroupbySameKeysOfSortedbyKeys()
}
}

/**
 * Table is sorted by (custkey) only while the query groups by (custkey, name).
 * The expected plan is a SINGLE aggregation with an empty pre-grouped list (non-streaming)
 * sitting on top of a project over the table scan.
 */
@Test
public void testGroupbySupersetOfSortedKeys()
{
    QueryRunner runner = getQueryRunner();

    String createTableSql = "CREATE TABLE test_customer7 WITH ( \n" +
            " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" +
            " sorted_by = ARRAY['custkey'], partitioned_by=array['ds'], \n" +
            " format = 'DWRF' ) AS \n" +
            "SELECT *, '2021-07-11' as ds FROM customer LIMIT 1000\n";
    String groupBySql = "SELECT custkey, name, COUNT(*) FROM test_customer7 \n" +
            "WHERE ds = '2021-07-11' GROUP BY 1, 2";

    try {
        runner.execute(createTableSql);

        // Streaming aggregation cannot be enabled for this query, but the streaming
        // aggregation session property would still disable splittable.
        assertPlan(
                streamingAggregationEnabled(),
                groupBySql,
                anyTree(aggregation(
                        singleGroupingSet("custkey", "name"),
                        // the partial aggregation function takes no parameter
                        ImmutableMap.of(Optional.empty(), functionCall("count", ImmutableList.of())),
                        // non-streaming: nothing is reported as pre-grouped
                        ImmutableList.of(),
                        ImmutableMap.of(),
                        Optional.empty(),
                        SINGLE,
                        node(
                                ProjectNode.class,
                                tableScan("test_customer7", ImmutableMap.of("custkey", "custkey", "name", "name"))))));
    }
    finally {
        // Drop the table even when the plan assertion fails.
        runner.execute("DROP TABLE IF EXISTS test_customer7");
    }
}

@Test
public void testGroupbyKeysNotPrefixOfSortedKeys()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ public final class SystemSessionProperties
public static final String MAX_STAGE_COUNT_FOR_EAGER_SCHEDULING = "max_stage_count_for_eager_scheduling";
public static final String HYPERLOGLOG_STANDARD_ERROR_WARNING_THRESHOLD = "hyperloglog_standard_error_warning_threshold";
public static final String PREFER_MERGE_JOIN = "prefer_merge_join";
public static final String SEGMENTED_AGGREGATION_ENABLED = "segmented_aggregation_enabled";

//TODO: Prestissimo related session properties that are temporarily put here. They will be relocated in the future
public static final String PRESTISSIMO_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "simplified_expression_evaluation_enabled";
Expand Down Expand Up @@ -1190,6 +1191,11 @@ public SystemSessionProperties(
"To make it work, the connector needs to guarantee and expose the data properties of the underlying table.",
featuresConfig.isPreferMergeJoin(),
true),
booleanProperty(
SEGMENTED_AGGREGATION_ENABLED,
"Enable segmented aggregation.",
featuresConfig.isSegmentedAggregationEnabled(),
true),
new PropertyMetadata<>(
AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY,
format("Set the strategy used to rewrite AGG IF to AGG FILTER. Options are %s",
Expand Down Expand Up @@ -2101,6 +2107,11 @@ public static boolean preferMergeJoin(Session session)
return session.getSystemProperty(PREFER_MERGE_JOIN, Boolean.class);
}

/**
 * Looks up the {@link #SEGMENTED_AGGREGATION_ENABLED} system property on the given session.
 *
 * @param session session whose system properties are consulted
 * @return the boolean value of the property for this session
 */
public static boolean isSegmentedAggregationEnabled(Session session)
{
    Boolean segmentedAggregationEnabled = session.getSystemProperty(SEGMENTED_AGGREGATION_ENABLED, Boolean.class);
    return segmentedAggregationEnabled;
}

public static AggregationIfToFilterRewriteStrategy getAggregationIfToFilterRewriteStrategy(Session session)
{
return session.getSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, AggregationIfToFilterRewriteStrategy.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ public class FeaturesConfig

private boolean streamingForPartialAggregationEnabled;
private boolean preferMergeJoin;
private boolean segmentedAggregationEnabled;

private int maxStageCountForEagerScheduling = 25;
private boolean quickDistinctLimitEnabled;
Expand Down Expand Up @@ -2058,6 +2059,18 @@ public FeaturesConfig setPreferMergeJoin(boolean preferMergeJoin)
return this;
}

/**
 * @return the currently configured value of the segmented aggregation flag
 */
public boolean isSegmentedAggregationEnabled()
{
    return this.segmentedAggregationEnabled;
}

/**
 * Sets the segmented aggregation flag; annotated with the configuration key
 * {@code optimizer.segmented-aggregation-enabled}.
 *
 * @param segmentedAggregationEnabled new value of the flag
 * @return this config instance, so setter calls can be chained
 */
@Config("optimizer.segmented-aggregation-enabled")
public FeaturesConfig setSegmentedAggregationEnabled(boolean segmentedAggregationEnabled)
{
this.segmentedAggregationEnabled = segmentedAggregationEnabled;
return this;
}

public boolean isQuickDistinctLimitEnabled()
{
return quickDistinctLimitEnabled;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ public class PushPartialAggregationThroughJoin

private static boolean isSupportedAggregationNode(AggregationNode aggregationNode)
{
// Don't split streaming aggregations
if (aggregationNode.isStreamable()) {
// Don't split streaming aggregations or segmented aggregations
if (aggregationNode.isStreamable() || aggregationNode.isSegmentedAggregationEligible()) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import static com.facebook.presto.SystemSessionProperties.isEnforceFixedDistributionForOutputOperator;
import static com.facebook.presto.SystemSessionProperties.isJoinSpillingEnabled;
import static com.facebook.presto.SystemSessionProperties.isQuickDistinctLimitEnabled;
import static com.facebook.presto.SystemSessionProperties.isSegmentedAggregationEnabled;
import static com.facebook.presto.SystemSessionProperties.isSpillEnabled;
import static com.facebook.presto.SystemSessionProperties.isTableWriterMergeOperatorEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
Expand Down Expand Up @@ -331,10 +332,26 @@ public PlanWithProperties visitAggregation(AggregationNode node, StreamPreferred
PlanWithProperties child = planAndEnforce(node.getSource(), childRequirements, childRequirements);

List<VariableReferenceExpression> preGroupedSymbols = ImmutableList.of();
if (!LocalProperties.match(child.getProperties().getLocalProperties(), LocalProperties.grouped(groupingKeys)).get(0).isPresent()) {
// Logic in LocalProperties.match(localProperties, groupingKeys)
// 1. Extract the longest prefix of localProperties to a set that is a subset of groupingKeys
// 2. Iterate over the group-by keys and add the elements that are not in that set to the result
// Result would be a List of one element: Optional<GroupingProperty>, GroupingProperty would contain one/multiple elements from step 2
// Eg:
// [A, B] [(B, A)] -> List.of(Optional.empty())
// [A, B] [B] -> List.of(Optional.of(GroupingProperty(B)))
// [A, B] [A] -> List.of(Optional.empty())
// [A, B] [(A, C)] -> List.of(Optional.of(GroupingProperty(C)))
// [A, B] [(D, A, C)] -> List.of(Optional.of(GroupingProperty(D, C)))
List<Optional<LocalProperty<VariableReferenceExpression>>> matchResult = LocalProperties.match(child.getProperties().getLocalProperties(), LocalProperties.grouped(groupingKeys));
if (!matchResult.get(0).isPresent()) {
// !isPresent() indicates the property was satisfied completely
preGroupedSymbols = groupingKeys;
}
else if (matchResult.get(0).get().getColumns().size() < groupingKeys.size() && isSegmentedAggregationEnabled(session)) {
// If the result size equals the original groupingKeys size: none of the grouping keys are pre-grouped, so segmented aggregation can't be enabled
// Otherwise: only some of the grouping keys are pre-grouped, so segmented aggregation can be enabled; the result holds the grouping keys that are not pre-grouped
preGroupedSymbols = groupingKeys.stream().filter(groupingKey -> !matchResult.get(0).get().getColumns().contains(groupingKey)).collect(toImmutableList());
}

AggregationNode result = new AggregationNode(
node.getSourceLocation(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ public PlanWithProperties visitAggregation(AggregationNode node, HashComputation
{
Optional<HashComputation> groupByHash = Optional.empty();
List<VariableReferenceExpression> groupingKeys = node.getGroupingKeys();
if (!node.isStreamable() && !canSkipHashGeneration(node.getGroupingKeys())) {
if (!node.isStreamable() && !node.isSegmentedAggregationEligible() && !canSkipHashGeneration(node.getGroupingKeys())) {
// todo: for segmented aggregation, add optimizations for the fields that need to compute hash
groupByHash = computeHash(groupingKeys, functionAndTypeManager);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,9 @@ public Void visitAggregation(AggregationNode node, Void context)
if (node.getStep() != AggregationNode.Step.SINGLE) {
type = format("(%s)", node.getStep().toString());
}
if (node.isSegmentedAggregationEligible()) {
type = format("%s(SEGMENTED, %s)", type, node.getPreGroupedVariables());
}
if (node.isStreamable()) {
type = format("%s(STREAMING)", type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ public void testDefaults()
.setMaxStageCountForEagerScheduling(25)
.setHyperloglogStandardErrorWarningThreshold(0.004)
.setPreferMergeJoin(false)
.setSegmentedAggregationEnabled(false)
.setQueryAnalyzerTimeout(new Duration(3, MINUTES))
.setQuickDistinctLimitEnabled(false));
}
Expand Down Expand Up @@ -340,6 +341,7 @@ public void testExplicitPropertyMappings()
.put("execution-policy.max-stage-count-for-eager-scheduling", "123")
.put("hyperloglog-standard-error-warning-threshold", "0.02")
.put("optimizer.prefer-merge-join", "true")
.put("optimizer.segmented-aggregation-enabled", "true")
.put("planner.query-analyzer-timeout", "10s")
.put("optimizer.quick-distinct-limit-enabled", "true")
.build();
Expand Down Expand Up @@ -484,6 +486,7 @@ public void testExplicitPropertyMappings()
.setMaxStageCountForEagerScheduling(123)
.setHyperloglogStandardErrorWarningThreshold(0.02)
.setPreferMergeJoin(true)
.setSegmentedAggregationEnabled(true)
.setQueryAnalyzerTimeout(new Duration(10, SECONDS))
.setQuickDistinctLimitEnabled(true);
assertFullMapping(properties, expected);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,18 @@ public PlanNode replaceChildren(List<PlanNode> newChildren)

/**
 * Tells whether this aggregation can run as a fully streaming aggregation.
 *
 * <p>All of the following must hold:
 * <ul>
 * <li>at least one pre-grouped variable is present;</li>
 * <li>there is exactly one grouping set and no global grouping set;</li>
 * <li>the number of pre-grouped variables equals the number of grouping keys.</li>
 * </ul>
 *
 * <p>The last condition is what separates this case from
 * {@code isSegmentedAggregationEligible()}, which requires strictly fewer pre-grouped
 * variables than grouping keys. A single return statement is kept here: an earlier
 * return without the size comparison would make that comparison unreachable.
 *
 * @return {@code true} when every condition above is satisfied
 */
public boolean isStreamable()
{
    return !preGroupedVariables.isEmpty()
            && groupingSets.getGroupingSetCount() == 1
            && groupingSets.getGlobalGroupingSets().isEmpty()
            && preGroupedVariables.size() == groupingSets.groupingKeys.size();
}

/**
 * Tells whether this aggregation qualifies for segmented aggregation.
 *
 * <p>It qualifies when there is at least one pre-grouped variable, exactly one grouping
 * set, no global grouping set, and strictly fewer pre-grouped variables than grouping keys.
 *
 * @return {@code true} when all of those conditions hold
 */
public boolean isSegmentedAggregationEligible()
{
    // Checks run in the same order as a short-circuiting conjunction would evaluate them.
    if (preGroupedVariables.isEmpty()) {
        return false;
    }
    if (groupingSets.getGroupingSetCount() != 1) {
        return false;
    }
    if (!groupingSets.getGlobalGroupingSets().isEmpty()) {
        return false;
    }
    return preGroupedVariables.size() < groupingSets.groupingKeys.size();
}

@Override
Expand Down

0 comments on commit 84436f5

Please sign in to comment.