diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java index e0ad9a3282f38..7c3e8b0607d01 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/optimization/JdbcComputePushdown.java @@ -139,7 +139,8 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context) oldTableScanNode.getAssignments(), oldTableScanNode.getTableConstraints(), oldTableScanNode.getCurrentConstraint(), - oldTableScanNode.getEnforcedConstraint()); + oldTableScanNode.getEnforcedConstraint(), + oldTableScanNode.getCteMaterializationInfo()); return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), newTableScanNode, node.getPredicate()); } diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/optimization/ClickHouseComputePushdown.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/optimization/ClickHouseComputePushdown.java index c025db58fa542..d6d18093cdb89 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/optimization/ClickHouseComputePushdown.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/optimization/ClickHouseComputePushdown.java @@ -200,7 +200,8 @@ private Optional<PlanNode> tryCreatingNewScanNode(PlanNode plan) ImmutableList.copyOf(assignments.keySet()), assignments.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, (e) -> (ColumnHandle) (e.getValue()))), tableScanNode.getCurrentConstraint(), - tableScanNode.getEnforcedConstraint())); + tableScanNode.getEnforcedConstraint(), + tableScanNode.getCteMaterializationInfo())); } @Override @@ -288,7 +289,8 @@ public PlanNode visitFilter(FilterNode node, Void context) oldTableScanNode.getOutputVariables(), oldTableScanNode.getAssignments(), oldTableScanNode.getCurrentConstraint(), - oldTableScanNode.getEnforcedConstraint()); + oldTableScanNode.getEnforcedConstraint(), + oldTableScanNode.getCteMaterializationInfo()); return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), newTableScanNode, node.getPredicate()); } diff --git a/presto-docs/src/main/sphinx/admin/cte-materialization.rst b/presto-docs/src/main/sphinx/admin/cte-materialization.rst index 332bc02cb2891..9b169f02cee2c 100644 --- a/presto-docs/src/main/sphinx/admin/cte-materialization.rst +++ b/presto-docs/src/main/sphinx/admin/cte-materialization.rst @@ -118,7 +118,6 @@ This setting specifies the Hash function type for CTE materialization. Use the ``hive.bucket_function_type_for_cte_materialization`` session property to set on a per-query basis. - ``query.max-written-intermediate-bytes`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -129,6 +128,17 @@ This setting defines a cap on the amount of data that can be written during CTE Use the ``query_max_written_intermediate_bytes`` session property to set on a per-query basis. +``enhanced-cte-scheduling-enabled`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``true`` + +Flag to enable or disable the enhanced-cte-blocking during CTE Materialization. Enhanced CTE blocking restricts only the table scan stages of the CTE TableScan, rather than blocking entire plan sections, including the main query, until the query completes. +This approach can improve latency in scenarios where parts of the query can execute concurrently with CTE materialization writes. + +Use the ``enhanced_cte_scheduling_enabled`` session property to set on a per-query basis. + How to Participate in Development --------------------------------- diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidPlanOptimizer.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidPlanOptimizer.java index 81401f78c8c7b..0259bbc8c2234 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidPlanOptimizer.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidPlanOptimizer.java @@ -173,7 +173,8 @@ private Optional<PlanNode> tryCreatingNewScanNode(PlanNode plan) assignments.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, (e) -> (ColumnHandle) (e.getValue()))), tableScanNode.getTableConstraints(), tableScanNode.getCurrentConstraint(), - tableScanNode.getEnforcedConstraint())); + tableScanNode.getEnforcedConstraint(), + tableScanNode.getCteMaterializationInfo())); } @Override diff --git a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryBase.java b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryBase.java index 428d6d7bf300d..912e970b709f5 100644 --- a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryBase.java +++ b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryBase.java @@ -147,7 +147,7 @@ protected TableScanNode tableScan(PlanBuilder planBuilder, DruidTableHandle conn variables, assignments.build(), TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); } protected FilterNode filter(PlanBuilder planBuilder, PlanNode source, RowExpression predicate) diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/rule/BaseSubfieldExtractionRewriter.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/rule/BaseSubfieldExtractionRewriter.java index add6ddb6b7a60..070b940f3db9f 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/rule/BaseSubfieldExtractionRewriter.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/rule/BaseSubfieldExtractionRewriter.java @@ -307,7 +307,8 @@ private static TableScanNode getTableScanNode( tableScan.getAssignments(), tableScan.getTableConstraints(), pushdownFilterResult.getLayout().getPredicate(), - TupleDomain.all()); + TupleDomain.all(), + tableScan.getCteMaterializationInfo()); } private static ExtractionResult intersectExtractionResult( diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/rule/HiveAddRequestedColumnsToLayout.java b/presto-hive/src/main/java/com/facebook/presto/hive/rule/HiveAddRequestedColumnsToLayout.java index 2e159874a582f..36b7966641699 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/rule/HiveAddRequestedColumnsToLayout.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/rule/HiveAddRequestedColumnsToLayout.java @@ -71,7 +71,8 @@ public PlanNode visitTableScan(TableScanNode tableScan, RewriteContext<Void> con tableScan.getAssignments(), tableScan.getTableConstraints(), tableScan.getCurrentConstraint(), - tableScan.getEnforcedConstraint()); + tableScan.getEnforcedConstraint(), + tableScan.getCteMaterializationInfo()); } } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/rule/HivePartialAggregationPushdown.java b/presto-hive/src/main/java/com/facebook/presto/hive/rule/HivePartialAggregationPushdown.java index d5efb222fd16a..cf1ff92e6d50a 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/rule/HivePartialAggregationPushdown.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/rule/HivePartialAggregationPushdown.java @@ -297,7 +297,8 @@ private Optional<PlanNode> tryPartialAggregationPushdown(PlanNode plan) ImmutableMap.copyOf(assignments), oldTableScanNode.getTableConstraints(), oldTableScanNode.getCurrentConstraint(), - oldTableScanNode.getEnforcedConstraint())); + oldTableScanNode.getEnforcedConstraint(), + oldTableScanNode.getCteMaterializationInfo())); } @Override diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java index c86b91f206f10..1e4f13b963905 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java @@ -1171,6 +1171,15 @@ public void testChainedCteProjectionAndFilterPushDown() generateMaterializedCTEInformation("cte5", 1, false, true))); } + @Test + public void testCTEMaterializationWithEnhancedScheduling() + { + QueryRunner queryRunner = getQueryRunner(); + String sql = "WITH temp as (SELECT orderkey FROM ORDERS) " + + "SELECT * FROM temp t1 JOIN (SELECT custkey FROM customer) c ON t1.orderkey=c.custkey"; + verifyResults(queryRunner, sql, ImmutableList.of(generateMaterializedCTEInformation("temp", 1, false, true))); + } + @Test public void testWrittenIntemediateByteLimit() throws Exception diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergEqualityDeleteAsJoin.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergEqualityDeleteAsJoin.java index f54d2e78b1282..401b298ed079b 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergEqualityDeleteAsJoin.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergEqualityDeleteAsJoin.java @@ -341,7 +341,8 @@ private TableScanNode createDeletesTableScan(ImmutableMap<VariableReferenceExpre outputs, deleteColumnAssignments, TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), + Optional.empty()); } /** @@ -382,7 +383,8 @@ private TableScanNode createNewRoot(TableScanNode node, IcebergTableHandle icebe assignmentsBuilder.build(), node.getTableConstraints(), node.getCurrentConstraint(), - node.getEnforcedConstraint()); + node.getEnforcedConstraint(), + node.getCteMaterializationInfo()); } /** diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergPlanOptimizer.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergPlanOptimizer.java index 35f8c6b356bb7..dac79f9adfafc 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergPlanOptimizer.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergPlanOptimizer.java @@ -238,7 +238,8 @@ public PlanNode visitFilter(FilterNode filter, RewriteContext<Void> context) .intersect(tableScan.getCurrentConstraint()), predicateNotChangedBySimplification ? identityPartitionColumnPredicate.intersect(tableScan.getEnforcedConstraint()) : - tableScan.getEnforcedConstraint()); + tableScan.getEnforcedConstraint(), + tableScan.getCteMaterializationInfo()); if (TRUE_CONSTANT.equals(remainingFilterExpression) && predicateNotChangedBySimplification) { return newTableScan; diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 7f23e8172e800..93a485574f1a9 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -234,6 +234,7 @@ public final class SystemSessionProperties public static final String QUERY_RETRY_MAX_EXECUTION_TIME = "query_retry_max_execution_time"; public static final String PARTIAL_RESULTS_ENABLED = "partial_results_enabled"; public static final String PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD = "partial_results_completion_ratio_threshold"; + public static final String ENHANCED_CTE_SCHEDULING_ENABLED = "enhanced-cte-scheduling-enabled"; public static final String PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER = "partial_results_max_execution_time_multiplier"; public static final String OFFSET_CLAUSE_ENABLED = "offset_clause_enabled"; public static final String VERBOSE_EXCEEDED_MEMORY_LIMIT_ERRORS_ENABLED = "verbose_exceeded_memory_limit_errors_enabled"; @@ -1282,6 +1283,11 @@ public SystemSessionProperties( "Minimum query completion ratio threshold for partial results", featuresConfig.getPartialResultsCompletionRatioThreshold(), false), + booleanProperty( + ENHANCED_CTE_SCHEDULING_ENABLED, + "Applicable for CTE Materialization. If enabled, only tablescans of the pending tablewriters are blocked and other stages can continue.", + featuresConfig.getEnhancedCTESchedulingEnabled(), + true), booleanProperty( OFFSET_CLAUSE_ENABLED, "Enable support for OFFSET clause", @@ -2690,6 +2696,11 @@ public static double getPartialResultsCompletionRatioThreshold(Session session) return session.getSystemProperty(PARTIAL_RESULTS_COMPLETION_RATIO_THRESHOLD, Double.class); } + public static boolean isEnhancedCTESchedulingEnabled(Session session) + { + return isCteMaterializationApplicable(session) & session.getSystemProperty(ENHANCED_CTE_SCHEDULING_ENABLED, Boolean.class); + } + public static double getPartialResultsMaxExecutionTimeMultiplier(Session session) { return session.getSystemProperty(PARTIAL_RESULTS_MAX_EXECUTION_TIME_MULTIPLIER, Double.class); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java index ee2159eba56c7..2c6b7e8b60dea 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlStageExecution.java @@ -26,10 +26,15 @@ import com.facebook.presto.metadata.Split; import com.facebook.presto.server.remotetask.HttpRemoteTask; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.plan.CteMaterializationInfo; import com.facebook.presto.spi.plan.PlanFragmentId; +import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.plan.TableFinishNode; +import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.split.RemoteSplit; import com.facebook.presto.sql.planner.PlanFragment; +import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableMap; @@ -60,8 +65,10 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import java.util.stream.Collectors; import static com.facebook.presto.SystemSessionProperties.getMaxFailedTaskPercentage; +import static com.facebook.presto.SystemSessionProperties.isEnhancedCTESchedulingEnabled; import static com.facebook.presto.failureDetector.FailureDetector.State.GONE; import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; @@ -557,7 +564,6 @@ private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, M // stage finished while we were scheduling this task task.abort(); } - return task; } @@ -594,6 +600,59 @@ private static Split createRemoteSplitFor(TaskId taskId, URI remoteSourceTaskLoc return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(new Location(splitLocation), remoteSourceTaskId)); } + private static String getCteIdFromSource(PlanNode source) + { + // Traverse the plan node tree to find a TableWriterNode with TemporaryTableInfo + return PlanNodeSearcher.searchFrom(source) + .where(planNode -> planNode instanceof TableFinishNode) + .findFirst() + .flatMap(planNode -> ((TableFinishNode) planNode).getCteMaterializationInfo()) + .map(CteMaterializationInfo::getCteId) + .orElseThrow(() -> new IllegalStateException("TemporaryTableInfo has no CTE ID")); + } + + public boolean isCTETableFinishStage() + { + return PlanNodeSearcher.searchFrom(planFragment.getRoot()) + .where(planNode -> planNode instanceof TableFinishNode && + ((TableFinishNode) planNode).getCteMaterializationInfo().isPresent()) + .findSingle() + .isPresent(); + } + + public String getCTEWriterId() + { + // Validate that this is a CTE TableFinish stage and return the associated CTE ID + if (!isCTETableFinishStage()) { + throw new IllegalStateException("This stage is not a CTE writer stage"); + } + return getCteIdFromSource(planFragment.getRoot()); + } + + public boolean requiresMaterializedCTE() + { + if (!isEnhancedCTESchedulingEnabled(session)) { + return false; + } + // Search for TableScanNodes and check if they reference TemporaryTableInfo + return PlanNodeSearcher.searchFrom(planFragment.getRoot()) + .where(planNode -> planNode instanceof TableScanNode) + .findAll().stream() + .anyMatch(planNode -> ((TableScanNode) planNode).getCteMaterializationInfo().isPresent()); + } + + public List<String> getRequiredCTEList() + { + // Collect all CTE IDs referenced by TableScanNodes with TemporaryTableInfo + return PlanNodeSearcher.searchFrom(planFragment.getRoot()) + .where(planNode -> planNode instanceof TableScanNode) + .findAll().stream() + .map(planNode -> ((TableScanNode) planNode).getCteMaterializationInfo() + .orElseThrow(() -> new IllegalStateException("TableScanNode has no TemporaryTableInfo"))) + .map(CteMaterializationInfo::getCteId) + .collect(Collectors.toList()); + } + private void updateTaskStatus(TaskId taskId, TaskStatus taskStatus) { StageExecutionState stageExecutionState = getState(); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java new file mode 100644 index 0000000000000..a82ef6508d954 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.scheduler; + +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/* + * Tracks the completion status of table-finish nodes that write temporary tables for CTE materialization. + * CTEMaterializationTracker manages a map of materialized CTEs and their associated materialization futures. + * When a stage includes a CTE table finish, it marks the corresponding CTE as materialized and completes + * the associated future. + * This signals the scheduler that some dependency has been resolved, prompting it to resume/continue scheduling. + */ +public class CTEMaterializationTracker +{ + private final Map<String, SettableFuture<Void>> materializationFutures = new ConcurrentHashMap<>(); + + public ListenableFuture<Void> getFutureForCTE(String cteName) + { + return Futures.nonCancellationPropagating( + materializationFutures.compute(cteName, (key, existingFuture) -> { + if (existingFuture == null) { + // Create a new SettableFuture and store it internally + return SettableFuture.create(); + } + Preconditions.checkArgument(!existingFuture.isCancelled(), + String.format("Error: Existing future was found cancelled in CTEMaterializationTracker for cte", cteName)); + return existingFuture; + })); + } + + public void markCTEAsMaterialized(String cteName) + { + materializationFutures.compute(cteName, (key, existingFuture) -> { + if (existingFuture == null) { + SettableFuture<Void> completedFuture = SettableFuture.create(); + completedFuture.set(null); + return completedFuture; + } + Preconditions.checkArgument(!existingFuture.isCancelled(), + String.format("Error: Existing future was found cancelled in CTEMaterializationTracker for cte", cteName)); + existingFuture.set(null); // Notify all listeners + return existingFuture; + }); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java index 8e965aab792c3..9febc345114f2 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -74,6 +74,8 @@ public class FixedSourcePartitionedScheduler private final Queue<Integer> tasksToRecover = new ConcurrentLinkedQueue<>(); + private final CTEMaterializationTracker cteMaterializationTracker; + @GuardedBy("this") private boolean closed; @@ -87,13 +89,15 @@ public FixedSourcePartitionedScheduler( int splitBatchSize, OptionalInt concurrentLifespansPerTask, NodeSelector nodeSelector, - List<ConnectorPartitionHandle> partitionHandles) + List<ConnectorPartitionHandle> partitionHandles, + CTEMaterializationTracker cteMaterializationTracker) { requireNonNull(stage, "stage is null"); requireNonNull(splitSources, "splitSources is null"); requireNonNull(bucketNodeMap, "bucketNodeMap is null"); checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty"); requireNonNull(partitionHandles, "partitionHandles is null"); + this.cteMaterializationTracker = cteMaterializationTracker; this.stage = stage; this.nodes = ImmutableList.copyOf(nodes); @@ -179,6 +183,29 @@ public ScheduleResult schedule() { // schedule a task on every node in the distribution List<RemoteTask> newTasks = ImmutableList.of(); + + // CTE Materialization Check + if (stage.requiresMaterializedCTE()) { + List<ListenableFuture<?>> blocked = new ArrayList<>(); + List<String> requiredCTEIds = stage.getRequiredCTEList(); + for (String cteId : requiredCTEIds) { + ListenableFuture<Void> cteFuture = cteMaterializationTracker.getFutureForCTE(cteId); + if (!cteFuture.isDone()) { + // Add CTE materialization future to the blocked list + blocked.add(cteFuture); + } + } + // If any CTE is not materialized, return a blocked ScheduleResult + if (!blocked.isEmpty()) { + return ScheduleResult.blocked( + false, + newTasks, + whenAnyComplete(blocked), + BlockedReason.WAITING_FOR_CTE_MATERIALIZATION, + 0); + } + } + // schedule a task on every node in the distribution if (!scheduledTasks) { newTasks = Streams.mapWithIndex( nodes.stream(), @@ -191,9 +218,8 @@ public ScheduleResult schedule() // notify listeners that we have scheduled all tasks so they can set no more buffers or exchange splits stage.transitionToFinishedTaskScheduling(); } - - boolean allBlocked = true; List<ListenableFuture<?>> blocked = new ArrayList<>(); + boolean allBlocked = true; BlockedReason blockedReason = BlockedReason.NO_ACTIVE_DRIVER_GROUP; if (groupedLifespanScheduler.isPresent()) { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java index ed85bfff8fd94..dfc6288f78617 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScheduleResult.java @@ -57,6 +57,11 @@ public enum BlockedReason * grouped execution where there are multiple lifespans per task). */ MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE, + + /** + * Waiting for the completion of CTE materialization by the table writer. + */ + WAITING_FOR_CTE_MATERIALIZATION, /**/; public BlockedReason combineWith(BlockedReason other) @@ -64,6 +69,7 @@ public BlockedReason combineWith(BlockedReason other) switch (this) { case WRITER_SCALING: throw new IllegalArgumentException("cannot be combined"); + case WAITING_FOR_CTE_MATERIALIZATION: case NO_ACTIVE_DRIVER_GROUP: return other; case SPLIT_QUEUES_FULL: diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java index cba5736650fdb..26b77ef163ca5 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java @@ -170,7 +170,8 @@ public SectionExecution createSectionExecutions( boolean summarizeTaskInfo, RemoteTaskFactory remoteTaskFactory, SplitSourceFactory splitSourceFactory, - int attemptId) + int attemptId, + CTEMaterializationTracker cteMaterializationTracker) { // Only fetch a distribution once per section to ensure all stages see the same machine assignments Map<PartitioningHandle, NodePartitionMap> partitioningCache = new HashMap<>(); @@ -186,7 +187,8 @@ public SectionExecution createSectionExecutions( summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - attemptId); + attemptId, + cteMaterializationTracker); StageExecutionAndScheduler rootStage = getLast(sectionStages); rootStage.getStageExecution().setOutputBuffers(outputBuffers); return new SectionExecution(rootStage, sectionStages); @@ -205,7 +207,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions( boolean summarizeTaskInfo, RemoteTaskFactory remoteTaskFactory, SplitSourceFactory splitSourceFactory, - int attemptId) + int attemptId, + CTEMaterializationTracker cteMaterializationTracker) { ImmutableList.Builder<StageExecutionAndScheduler> stageExecutionAndSchedulers = ImmutableList.builder(); @@ -240,7 +243,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions( summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - attemptId); + attemptId, + cteMaterializationTracker); stageExecutionAndSchedulers.addAll(subTree); childStagesBuilder.add(getLast(subTree).getStageExecution()); } @@ -262,7 +266,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions( stageExecution, partitioningHandle, tableWriteInfo, - childStageExecutions); + childStageExecutions, + cteMaterializationTracker); stageExecutionAndSchedulers.add(new StageExecutionAndScheduler( stageExecution, stageLinkage, @@ -281,7 +286,8 @@ private StageScheduler createStageScheduler( SqlStageExecution stageExecution, PartitioningHandle partitioningHandle, TableWriteInfo tableWriteInfo, - Set<SqlStageExecution> childStageExecutions) + Set<SqlStageExecution> childStageExecutions, + CTEMaterializationTracker cteMaterializationTracker) { Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(plan.getFragment(), session, tableWriteInfo); int maxTasksPerStage = getMaxTasksPerStage(session); @@ -341,7 +347,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { splitBatchSize, getConcurrentLifespansPerNode(session), nodeSelector, - ImmutableList.of(NOT_PARTITIONED)); + ImmutableList.of(NOT_PARTITIONED), + cteMaterializationTracker); } else if (!splitSources.isEmpty()) { // contains local source @@ -400,7 +407,8 @@ else if (!splitSources.isEmpty()) { splitBatchSize, getConcurrentLifespansPerNode(session), nodeScheduler.createNodeSelector(session, connectorId, nodePredicate), - connectorPartitionHandles); + connectorPartitionHandles, + cteMaterializationTracker); if (plan.getFragment().getStageExecutionDescriptor().isRecoverableGroupedExecution()) { stageExecution.registerStageTaskRecoveryCallback(taskId -> { checkArgument(taskId.getStageExecutionId().getStageId().equals(stageId), "The task did not execute this stage"); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java index 6d0507082e8bb..83b4bbaa8c69c 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java @@ -32,6 +32,8 @@ public class SplitSchedulerStats private final CounterStat splitQueuesFull = new CounterStat(); private final CounterStat mixedSplitQueuesFullAndWaitingForSource = new CounterStat(); private final CounterStat noActiveDriverGroup = new CounterStat(); + + private final CounterStat waitingForCTEMaterialization = new CounterStat(); private final DistributionStat splitsPerIteration = new DistributionStat(); @Managed @@ -62,6 +64,13 @@ public CounterStat getWaitingForSource() return waitingForSource; } + @Managed + @Nested + public CounterStat getWaitingForCTEMaterialization() + { + return waitingForCTEMaterialization; + } + @Managed @Nested public CounterStat getSplitQueuesFull() diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java index 198f23f1ba22e..a255a939f6ffa 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java @@ -79,6 +79,7 @@ import static com.facebook.presto.SystemSessionProperties.getMaxConcurrentMaterializations; import static com.facebook.presto.SystemSessionProperties.getPartialResultsCompletionRatioThreshold; import static com.facebook.presto.SystemSessionProperties.getPartialResultsMaxExecutionTimeMultiplier; +import static com.facebook.presto.SystemSessionProperties.isEnhancedCTESchedulingEnabled; import static com.facebook.presto.SystemSessionProperties.isPartialResultsEnabled; import static com.facebook.presto.SystemSessionProperties.isRuntimeOptimizerEnabled; import static com.facebook.presto.execution.BasicStageExecutionStats.aggregateBasicStageStats; @@ -149,6 +150,7 @@ public class SqlQueryScheduler private final AtomicBoolean scheduling = new AtomicBoolean(); private final PartialResultQueryTaskTracker partialResultQueryTaskTracker; + private final CTEMaterializationTracker cteMaterializationTracker = new CTEMaterializationTracker(); public static SqlQueryScheduler createSqlQueryScheduler( LocationFactory locationFactory, @@ -278,6 +280,17 @@ else if (state == CANCELED) { for (StageExecutionAndScheduler stageExecutionInfo : stageExecutions.values()) { SqlStageExecution stageExecution = stageExecutionInfo.getStageExecution(); + // Add a listener for state changes + if (stageExecution.isCTETableFinishStage()) { + stageExecution.addStateChangeListener(state -> { + if (state == StageExecutionState.FINISHED) { + String cteName = stageExecution.getCTEWriterId(); + log.debug("CTE write completed for: " + cteName); + // Notify the materialization tracker + cteMaterializationTracker.markCTEAsMaterialized(cteName); + } + }); + } stageExecution.addStateChangeListener(state -> { if (queryStateMachine.isDone()) { return; @@ -363,7 +376,8 @@ private List<StageExecutionAndScheduler> createStageExecutions( summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - 0).getSectionStages(); + 0, + cteMaterializationTracker).getSectionStages(); stages.addAll(sectionStages); return stages.build(); @@ -460,7 +474,9 @@ else if (!result.getBlocked().isDone()) { ScheduleResult.BlockedReason blockedReason = result.getBlockedReason().get(); switch (blockedReason) { case WRITER_SCALING: - // no-op + break; + case WAITING_FOR_CTE_MATERIALIZATION: + schedulerStats.getWaitingForCTEMaterialization().update(1); break; case WAITING_FOR_SOURCE: schedulerStats.getWaitingForSource().update(1); @@ -568,10 +584,12 @@ private List<StreamingPlanSection> getSectionsReadyForExecution() .map(section -> getStageExecution(section.getPlan().getFragment().getId()).getState()) .filter(state -> !state.isDone() && state != PLANNED) .count(); + return stream(forTree(StreamingPlanSection::getChildren).depthFirstPreOrder(sectionedPlan)) // get all sections ready for execution .filter(this::isReadyForExecution) - .limit(maxConcurrentMaterializations - runningPlanSections) + // for enhanced cte blocking we do not need a limit on the sections + .limit(isEnhancedCTESchedulingEnabled(session) ? Long.MAX_VALUE : maxConcurrentMaterializations - runningPlanSections) .map(this::tryCostBasedOptimize) .collect(toImmutableList()); } @@ -678,7 +696,8 @@ private void updateStageExecutions(StreamingPlanSection section, Map<PlanFragmen summarizeTaskInfo, remoteTaskFactory, splitSourceFactory, - 0); + 0, + cteMaterializationTracker); addStateChangeListeners(sectionExecution); Map<StageId, StageExecutionAndScheduler> updatedStageExecutions = sectionExecution.getSectionStages().stream() .collect(toImmutableMap(execution -> execution.getStageExecution().getStageExecutionId().getStageId(), identity())); @@ -774,10 +793,13 @@ private boolean isReadyForExecution(StreamingPlanSection section) // already scheduled return false; } - for (StreamingPlanSection child : section.getChildren()) { - SqlStageExecution rootStageExecution = getStageExecution(child.getPlan().getFragment().getId()); - if (rootStageExecution.getState() != FINISHED) { - return false; + if (!isEnhancedCTESchedulingEnabled(session)) { + // Enhanced cte blocking is not enabled so block till child sections are complete + for (StreamingPlanSection child : section.getChildren()) { + SqlStageExecution rootStageExecution = getStageExecution(child.getPlan().getFragment().getId()); + if (rootStageExecution.getState() != FINISHED) { + return false; + } } } return true; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java index d9c68ced227b7..7fb1b16b3b780 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java @@ -35,6 +35,7 @@ import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.CteMaterializationInfo; import com.facebook.presto.spi.plan.Partitioning; import com.facebook.presto.spi.plan.PartitioningHandle; import com.facebook.presto.spi.plan.PartitioningScheme; @@ -82,6 +83,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.concat; import static java.lang.String.format; +import static java.util.Collections.emptyList; import static java.util.function.Function.identity; // Planner Util for creating temporary tables @@ -99,7 +101,8 @@ public static TableScanNode createTemporaryTableScan( TableHandle tableHandle, List<VariableReferenceExpression> outputVariables, Map<VariableReferenceExpression, ColumnMetadata> variableToColumnMap, - Optional<PartitioningMetadata> expectedPartitioningMetadata) + Optional<PartitioningMetadata> expectedPartitioningMetadata, + Optional<String> cteId) { Map<String, ColumnHandle> columnHandles = metadata.getColumnHandles(session, tableHandle); Map<VariableReferenceExpression, ColumnMetadata> outputColumns = outputVariables.stream() @@ -126,11 +129,14 @@ public static TableScanNode createTemporaryTableScan( return new TableScanNode( sourceLocation, idAllocator.getNextId(), + Optional.empty(), selectedLayout.getLayout().getNewTableHandle(), outputVariables, assignments, + emptyList(), + TupleDomain.all(), TupleDomain.all(), - TupleDomain.all()); + cteId.map(CteMaterializationInfo::new)); } public static Map<VariableReferenceExpression, ColumnMetadata> assignTemporaryTableColumnNames(Collection<VariableReferenceExpression> outputVariables, @@ -181,7 +187,8 @@ public static TableFinishNode createTemporaryTableWriteWithoutExchanges( TableHandle tableHandle, List<VariableReferenceExpression> outputs, Map<VariableReferenceExpression, ColumnMetadata> variableToColumnMap, - VariableReferenceExpression outputVar) + VariableReferenceExpression outputVar, + Optional<String> cteId) { SchemaTableName schemaTableName = metadata.getTableMetadata(session, tableHandle).getTable(); TableWriterNode.InsertReference insertReference = new TableWriterNode.InsertReference(tableHandle, schemaTableName); @@ -215,7 +222,8 @@ public static TableFinishNode createTemporaryTableWriteWithoutExchanges( Optional.of(insertReference), outputVar, Optional.empty(), - Optional.empty()); + Optional.empty(), + cteId.map(CteMaterializationInfo::new)); } public static TableFinishNode createTemporaryTableWriteWithExchanges( @@ -353,7 +361,6 @@ public static TableFinishNode createTemporaryTableWriteWithExchanges( variableAllocator.newVariable("intermediatefragments", VARBINARY), variableAllocator.newVariable("intermediatetablecommitcontext", VARBINARY), enableStatsCollectionForTemporaryTable ? Optional.of(localAggregations.getIntermediateAggregation()) : Optional.empty()); - return new TableFinishNode( sourceLocation, idAllocator.getNextId(), @@ -364,7 +371,8 @@ public static TableFinishNode createTemporaryTableWriteWithExchanges( Optional.of(insertReference), variableAllocator.newVariable("rows", BIGINT), enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getFinalAggregation()) : Optional.empty(), - enableStatsCollectionForTemporaryTable ? Optional.of(statisticsResult.getDescriptor()) : Optional.empty()); + enableStatsCollectionForTemporaryTable ? Optional.of(statisticsResult.getDescriptor()) : Optional.empty(), + Optional.empty()); } public static StatisticAggregations.Parts splitIntoPartialAndFinal(StatisticAggregations statisticAggregations, VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index d6c329d195d90..4ee4907dc1cb2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -138,6 +138,7 @@ public class FeaturesConfig private boolean ignoreStatsCalculatorFailures = true; private boolean printStatsForNonJoinQuery; private boolean defaultFilterFactorEnabled; + private boolean enhancedCteSchedulingEnabled = true; // Give a default 10% selectivity coefficient factor to avoid hitting unknown stats in join stats estimates // which could result in syntactic join order. Set it to 0 to disable this feature private double defaultJoinSelectivityCoefficient; @@ -1290,6 +1291,18 @@ public boolean isDefaultFilterFactorEnabled() return defaultFilterFactorEnabled; } + @Config("enhanced-cte-scheduling-enabled") + public FeaturesConfig setEnhancedCTESchedulingEnabled(boolean enhancedCTEBlockingEnabled) + { + this.enhancedCteSchedulingEnabled = enhancedCTEBlockingEnabled; + return this; + } + + public boolean getEnhancedCTESchedulingEnabled() + { + return enhancedCteSchedulingEnabled; + } + @Config("optimizer.default-join-selectivity-coefficient") @ConfigDescription("Used when join selectivity estimation is unknown. Default 0 to disable the use of join selectivity, this will allow planner to fall back to FROM-clause join order when the join cardinality is unknown") public FeaturesConfig setDefaultJoinSelectivityCoefficient(double defaultJoinSelectivityCoefficient) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index b31906af0df5e..348e031539142 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -386,7 +386,8 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite temporaryTableHandle, exchange.getOutputVariables(), variableToColumnMap, - Optional.of(partitioningMetadata)); + Optional.of(partitioningMetadata), + Optional.empty()); checkArgument( !exchange.getPartitioningScheme().isReplicateNullsAndAny(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index 9ffc5fbeda66d..5825cbeca2dd6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -252,7 +252,8 @@ public Optional<PlanNode> visitTableFinish(TableFinishNode node, Context context node.getTarget().map(target -> CanonicalWriterTarget.from(target)), node.getRowCountVariable(), Optional.empty(), - Optional.empty()); + Optional.empty(), + node.getCteMaterializationInfo()); context.addPlan(node, new CanonicalPlan(result, strategy)); return Optional.of(result); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index ccc459bc0ceb5..54d0fdd3708b6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -240,7 +240,7 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme .putAll(tableScanOutputs.stream().collect(toImmutableMap(identity(), identity()))) .putAll(tableStatisticAggregation.getAdditionalVariables()) .build(); - TableScanNode scanNode = new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all()); + TableScanNode scanNode = new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all(), Optional.empty()); PlanNode project = PlannerUtils.addProjections(scanNode, idAllocator, assignments); PlanNode planNode = new StatisticsWriterNode( getSourceLocation(analyzeStatement), @@ -442,7 +442,8 @@ private RelationPlan createTableWriterPlan( // final aggregation is run within the TableFinishOperator to summarize collected statistics // by the partial aggregation from all of the writer nodes Optional.of(aggregations.getFinalAggregation()), - Optional.of(result.getDescriptor())); + Optional.of(result.getDescriptor()), + Optional.empty()); return new RelationPlan(commitNode, analysis.getRootScope(), commitNode.getOutputVariables()); } @@ -468,6 +469,7 @@ private RelationPlan createTableWriterPlan( Optional.of(target), variableAllocator.newVariable("rows", BIGINT), Optional.empty(), + Optional.empty(), Optional.empty()); return new RelationPlan(commitNode, analysis.getRootScope(), commitNode.getOutputVariables()); } @@ -487,6 +489,7 @@ private RelationPlan createDeletePlan(Analysis analysis, Delete node) Optional.of(deleteHandle), variableAllocator.newVariable("rows", BIGINT), Optional.empty(), + Optional.empty(), Optional.empty()); return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputVariables()); @@ -528,6 +531,7 @@ private RelationPlan createUpdatePlan(Analysis analysis, Update node) Optional.of(updateTarget), variableAllocator.newVariable("rows", BIGINT), Optional.empty(), + Optional.empty(), Optional.empty()); return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputVariables()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java index 90d1c80564a3e..5f267f4c94e22 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java @@ -257,6 +257,7 @@ public static Set<PlanNodeId> getOutputTableWriterNodeIds(PlanNode plan) .filter(node -> node instanceof TableWriterNode) .map(node -> (TableWriterNode) node) .filter(tableWriterNode -> !tableWriterNode.getIsTemporaryTableWriter().orElse(false)) + .map(node -> (TableWriterNode) node) .map(TableWriterNode::getId) .collect(toImmutableSet()); } @@ -306,7 +307,8 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext<Void> context) node.getOutputVariables(), node.getAssignments(), node.getCurrentConstraint(), - node.getEnforcedConstraint()); + node.getEnforcedConstraint(), + node.getCteMaterializationInfo()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index bd21f4234b272..a4fc4b9f656a5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -357,7 +357,7 @@ private static TableScanNode cloneTableScan(TableScanNode scanNode, Session sess newAssignments, scanNode.getTableConstraints(), scanNode.getCurrentConstraint(), - scanNode.getEnforcedConstraint()); + scanNode.getEnforcedConstraint(), scanNode.getCteMaterializationInfo()); } public static PlanNode clonePlanNode(PlanNode planNode, Session session, Metadata metadata, PlanNodeIdAllocator planNodeIdAllocator, List<VariableReferenceExpression> fieldsToKeep, Map<VariableReferenceExpression, VariableReferenceExpression> varMap) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 35487e1bf9984..023eb52fe1623 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -275,7 +275,7 @@ public DeleteNode plan(Delete node) // create table scan List<VariableReferenceExpression> outputVariables = outputVariablesBuilder.build(); - PlanNode tableScan = new TableScanNode(getSourceLocation(node), idAllocator.getNextId(), handle, outputVariables, columns.build(), TupleDomain.all(), TupleDomain.all()); + PlanNode tableScan = new TableScanNode(getSourceLocation(node), idAllocator.getNextId(), handle, outputVariables, columns.build(), TupleDomain.all(), TupleDomain.all(), Optional.empty()); Scope scope = Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(fields.build())).build(); RelationPlan relationPlan = new RelationPlan(tableScan, scope, outputVariables); @@ -344,7 +344,7 @@ public UpdateNode plan(Update node) // create table scan List<VariableReferenceExpression> outputVariables = outputVariablesBuilder.build(); - PlanNode tableScan = new TableScanNode(getSourceLocation(node), idAllocator.getNextId(), handle, outputVariables, columns.build(), TupleDomain.all(), TupleDomain.all()); + PlanNode tableScan = new TableScanNode(getSourceLocation(node), idAllocator.getNextId(), handle, outputVariables, columns.build(), TupleDomain.all(), TupleDomain.all(), Optional.empty()); Scope scope = Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(fields.build())).build(); RelationPlan relationPlan = new RelationPlan(tableScan, scope, outputVariables); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 1fb46f320c931..e260ea6b98d82 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -219,7 +219,8 @@ protected RelationPlan visitTable(Table node, SqlPlannerContext context) List<VariableReferenceExpression> outputVariables = outputVariablesBuilder.build(); List<TableConstraint<ColumnHandle>> tableConstraints = metadata.getTableMetadata(session, handle).getMetadata().getTableConstraintsHolder().getTableConstraintsWithColumnHandles(); context.incrementLeafNodes(session); - PlanNode root = new TableScanNode(getSourceLocation(node.getLocation()), idAllocator.getNextId(), handle, outputVariables, columns.build(), tableConstraints, TupleDomain.all(), TupleDomain.all()); + PlanNode root = new TableScanNode(getSourceLocation(node.getLocation()), idAllocator.getNextId(), handle, outputVariables, columns.build(), + tableConstraints, TupleDomain.all(), TupleDomain.all(), Optional.empty()); return new RelationPlan(root, scope, outputVariables); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java index 609b7e740bbcb..faf96d27adabc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java @@ -228,7 +228,8 @@ public Result apply(TableScanNode tableScanNode, Captures captures, Context cont tableScanNode.getAssignments(), tableScanNode.getTableConstraints(), layout.getLayout().getPredicate(), - TupleDomain.all())); + TupleDomain.all(), + tableScanNode.getCteMaterializationInfo())); } } @@ -324,7 +325,8 @@ private static PlanNode pushPredicateIntoTableScan( node.getAssignments(), node.getTableConstraints(), layout.getLayout().getPredicate(), - computeEnforced(newDomain, layout.getUnenforcedConstraint())); + computeEnforced(newDomain, layout.getUnenforcedConstraint()), + node.getCteMaterializationInfo()); // The order of the arguments to combineConjuncts matters: // * Unenforced constraints go first because they can only be simple column references, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java index 7225e5e183ac9..fef6aac344c20 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java @@ -46,6 +46,6 @@ protected Optional<PlanNode> pushDownProjectOff(PlanNodeIdAllocator idAllocator, filterKeys(tableScanNode.getAssignments(), referencedOutputs::contains), tableScanNode.getTableConstraints(), tableScanNode.getCurrentConstraint(), - tableScanNode.getEnforcedConstraint())); + tableScanNode.getEnforcedConstraint(), tableScanNode.getCteMaterializationInfo())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java index f448b8024fcb1..e969326ecfff5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java @@ -524,7 +524,8 @@ public Result apply(TableFinishNode node, Captures captures, Context context) node.getTarget(), node.getRowCountVariable(), rewrittenStatisticsAggregation, - node.getStatisticsAggregationDescriptor())); + node.getStatisticsAggregationDescriptor(), + node.getCteMaterializationInfo())); } return Result.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PhysicalCteOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PhysicalCteOptimizer.java index 17a93b627bba0..91f2c36d1099f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PhysicalCteOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PhysicalCteOptimizer.java @@ -138,7 +138,8 @@ public PlanNode visitCteProducer(CteProducerNode node, RewriteContext<PhysicalCt temporaryTableHandle, actualSource.getOutputVariables(), variableToColumnMap, - Optional.empty()), node.getOutputVariables())); + Optional.empty(), + Optional.of(node.getCteId())), node.getOutputVariables())); } catch (PrestoException e) { if (e.getErrorCode().equals(NOT_SUPPORTED.toErrorCode())) { @@ -159,7 +160,8 @@ public PlanNode visitCteProducer(CteProducerNode node, RewriteContext<PhysicalCt temporaryTableHandle, actualSource.getOutputVariables(), variableToColumnMap, - node.getRowCountVariable()); + node.getRowCountVariable(), + Optional.of(node.getCteId())); } public boolean isPlanRewritten() @@ -210,7 +212,8 @@ public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext<PhysicalCt newOutputVariables, newColumnAssignmentsMap, tempScan.getCurrentConstraint(), - tempScan.getEnforcedConstraint()); + tempScan.getEnforcedConstraint(), + tempScan.getCteMaterializationInfo()); // The temporary table scan might have columns removed by the UnaliasSymbolReferences and other optimizers (its a plan tree after all), // use originalOutputVariables (which are also canonicalized and maintained) and add them back diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 90510f8d111d4..2d8e8e5737a0f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -483,7 +483,8 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext<Set<VariableRe newAssignments, node.getTableConstraints(), node.getCurrentConstraint(), - node.getEnforcedConstraint()); + node.getEnforcedConstraint(), + node.getCteMaterializationInfo()); } @Override @@ -796,7 +797,8 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext<Set<Variab node.getTarget(), node.getRowCountVariable(), node.getStatisticsAggregation(), - node.getStatisticsAggregationDescriptor()); + node.getStatisticsAggregationDescriptor(), + node.getCteMaterializationInfo()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java index ada7285410e25..ae191bb75b175 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java @@ -399,7 +399,7 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext<Context> conte newAssignments.build(), node.getTableConstraints(), node.getCurrentConstraint(), - node.getEnforcedConstraint()); + node.getEnforcedConstraint(), node.getCteMaterializationInfo()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 9ddb6f97662ee..9805efad17939 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -283,7 +283,8 @@ public TableFinishNode map(TableFinishNode node, PlanNode source) node.getTarget(), map(node.getRowCountVariable()), node.getStatisticsAggregation().map(this::map), - node.getStatisticsAggregationDescriptor().map(descriptor -> descriptor.map(this::map))); + node.getStatisticsAggregationDescriptor().map(descriptor -> descriptor.map(this::map)), + node.getCteMaterializationInfo()); } public TableWriterMergeNode map(TableWriterMergeNode node, PlanNode source) diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java index be87bddc7b209..326c79a64469f 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -921,7 +921,7 @@ private TableScanNode tableScan(String id, List<VariableReferenceExpression> var variables, assignments.build(), TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); } private PlanNode project(String id, PlanNode source, VariableReferenceExpression variable, RowExpression expression) diff --git a/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java b/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java index 8731395f010a9..2ae96279b15b7 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java @@ -124,7 +124,7 @@ public MockRemoteTask createTableScanTask(TaskId taskId, InternalNode newNode, L ImmutableList.of(variable), ImmutableMap.of(variable, new TestingColumnHandle("column")), TupleDomain.all(), - TupleDomain.all()), + TupleDomain.all(), Optional.empty()), ImmutableSet.of(variable), SOURCE_DISTRIBUTION, ImmutableList.of(sourceId), diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java b/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java index de23ba82b0236..cdee04d9bdc83 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java @@ -120,7 +120,7 @@ public static PlanFragment createPlanFragment() ImmutableList.of(VARIABLE), ImmutableMap.of(VARIABLE, new TestingColumnHandle("column", 0, BIGINT)), TupleDomain.all(), - TupleDomain.all()), + TupleDomain.all(), Optional.empty()), ImmutableSet.of(VARIABLE), SOURCE_DISTRIBUTION, ImmutableList.of(TABLE_SCAN_NODE_ID), diff --git a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java index 301d6559a1d95..8c932566c4453 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java @@ -206,7 +206,7 @@ private static PlanFragment createBroadcastJoinPlanFragment(String name, PlanFra ImmutableList.of(variable), ImmutableMap.of(variable, new TestingColumnHandle("column")), TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); RemoteSourceNode remote = new RemoteSourceNode(Optional.empty(), new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of(), false, Optional.empty(), REPLICATE); PlanNode join = new JoinNode( @@ -266,7 +266,7 @@ private static PlanFragment createTableScanPlanFragment(String name) ImmutableList.of(variable), ImmutableMap.of(variable, new TestingColumnHandle("column")), TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); return createFragment(planNode); } diff --git a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java index dba204ed39644..00c4483a755c5 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java @@ -489,7 +489,8 @@ private static SubPlan createPlan() ImmutableList.of(variable), ImmutableMap.of(variable, new TestingColumnHandle("column")), TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), + Optional.empty()); RemoteSourceNode remote = new RemoteSourceNode(Optional.empty(), new PlanNodeId("remote_id"), new PlanFragmentId(0), ImmutableList.of(), false, Optional.empty(), GATHER); PlanFragment testFragment = new PlanFragment( diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestDriver.java b/presto-main/src/test/java/com/facebook/presto/operator/TestDriver.java index 6dcf637834938..f1b1256596292 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestDriver.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestDriver.java @@ -113,7 +113,7 @@ public class TestDriver ImmutableList.of(), ImmutableMap.of(), TupleDomain.all(), - TupleDomain.all()), + TupleDomain.all(), Optional.empty()), ImmutableMap.of(), singleGroupingSet(ImmutableList.of()), ImmutableList.of(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 120a48ca62fdc..55f0e2e5e254a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -250,7 +250,8 @@ public void testDefaults() .setEagerPlanValidationThreadPoolSize(20) .setPrestoSparkExecutionEnvironment(false) .setSingleNodeExecutionEnabled(false) - .setNativeExecutionScaleWritersThreadsEnabled(false)); + .setNativeExecutionScaleWritersThreadsEnabled(false) + .setEnhancedCTESchedulingEnabled(true)); } @Test @@ -450,6 +451,7 @@ public void testExplicitPropertyMappings() .put("presto-spark-execution-environment", "true") .put("single-node-execution-enabled", "true") .put("native-execution-scale-writer-threads-enabled", "true") + .put("enhanced-cte-scheduling-enabled", "false") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -646,7 +648,8 @@ public void testExplicitPropertyMappings() .setEagerPlanValidationThreadPoolSize(2) .setPrestoSparkExecutionEnvironment(true) .setSingleNodeExecutionEnabled(true) - .setNativeExecutionScaleWritersThreadsEnabled(true); + .setNativeExecutionScaleWritersThreadsEnabled(true) + .setEnhancedCTESchedulingEnabled(false); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalPlanGenerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalPlanGenerator.java index e38e5c4eae0c5..d64e462429046 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalPlanGenerator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestCanonicalPlanGenerator.java @@ -375,7 +375,7 @@ public void testCanonicalTableScanNodeField() .filter(f -> !f.isSynthetic()) .map(Field::getName) .collect(toImmutableSet()), - ImmutableSet.of("table", "assignments", "outputVariables", "currentConstraint", "enforcedConstraint", "tableConstraints")); + ImmutableSet.of("table", "assignments", "outputVariables", "currentConstraint", "enforcedConstraint", "tableConstraints", "cteMaterializationInfo")); assertEquals( Arrays.stream(CanonicalTableScanNode.class.getDeclaredFields()) .filter(f -> !f.isSynthetic()) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java index 0d3954f0ce6aa..46676111a82aa 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java @@ -144,7 +144,7 @@ public void setUp() ImmutableList.copyOf(assignments.keySet()), assignments, TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); } @Test @@ -373,7 +373,7 @@ public void testTableScan() ImmutableList.copyOf(assignments.keySet()), assignments, TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); RowExpression effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(effectivePredicate, TRUE_CONSTANT); @@ -384,7 +384,7 @@ public void testTableScan() ImmutableList.copyOf(assignments.keySet()), assignments, TupleDomain.none(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(effectivePredicate, FALSE_CONSTANT); @@ -395,7 +395,7 @@ public void testTableScan() ImmutableList.copyOf(assignments.keySet()), assignments, TupleDomain.withColumnDomains(ImmutableMap.of(scanAssignments.get(AV), Domain.singleValue(BIGINT, 1L))), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(1L), AV))); @@ -408,7 +408,7 @@ public void testTableScan() TupleDomain.withColumnDomains(ImmutableMap.of( scanAssignments.get(AV), Domain.singleValue(BIGINT, 1L), scanAssignments.get(BV), Domain.singleValue(BIGINT, 2L))), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(2L), BV), equals(bigintLiteral(1L), AV))); @@ -419,7 +419,7 @@ public void testTableScan() ImmutableList.copyOf(assignments.keySet()), assignments, TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(effectivePredicate, TRUE_CONSTANT); } @@ -783,7 +783,7 @@ private static TableScanNode tableScanNode(Map<VariableReferenceExpression, Colu ImmutableList.copyOf(scanAssignments.keySet()), scanAssignments, TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); } private static PlanNodeId newId() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java index b7873b75f87fa..1047d45985fe2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java @@ -162,7 +162,7 @@ public void testCustomPlanTranslator() ImmutableList.of(variable), ImmutableMap.of(variable, new TestingMetadata.TestingColumnHandle("column")), TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); PlanNode node1 = new CustomNodeA(new PlanNodeId("node1"), scan); PlanNode node2 = new CustomNodeB(new PlanNodeId("node2"), node1); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index f11154453ab41..07a0529289495 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -121,7 +121,7 @@ public void setUp() ImmutableList.copyOf(assignments.keySet()), assignments, TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); } @Test diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index ec4dfae373914..aa633c33a8624 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -545,7 +545,8 @@ public TableScanNode tableScan( assignments, ImmutableList.of(), currentConstraint, - enforcedConstraint); + enforcedConstraint, + Optional.empty()); } public TableScanNode tableScan( @@ -564,7 +565,7 @@ public TableScanNode tableScan( assignments, tableConstraints, currentConstraint, - enforcedConstraint); + enforcedConstraint, Optional.empty()); } public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode deleteSource, VariableReferenceExpression deleteRowId) @@ -592,7 +593,7 @@ public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode del Optional.of(deleteHandle), deleteRowId, Optional.empty(), - Optional.empty()); + Optional.empty(), Optional.empty()); } public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java index 393d363cc09b1..62b96f0973778 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java @@ -338,7 +338,7 @@ public PlanNode visitFilter(FilterNode node, Void context) tableScanNode.getAssignments(), tableScanNode.getTableConstraints(), TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); } return node; } diff --git a/presto-main/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java b/presto-main/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java index 7f496a5e41c92..e04f3a67bdabb 100644 --- a/presto-main/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java +++ b/presto-main/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java @@ -67,7 +67,7 @@ public class TestGraphvizPrinter ImmutableList.of(), ImmutableMap.of(), TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); private static final String TEST_TABLE_SCAN_NODE_INNER_OUTPUT = format( "label=\"{TableScan | [TableHandle \\{connectorId='%s', connectorHandle='%s', layout='Optional.empty'\\}]|Estimates: \\{rows: ? (0B), cpu: ?, memory: ?, network: ?\\}\n" + "}\", style=\"rounded, filled\", shape=record, fillcolor=deepskyblue", diff --git a/presto-parquet/src/main/java/com/facebook/presto/parquet/rule/ParquetDereferencePushDown.java b/presto-parquet/src/main/java/com/facebook/presto/parquet/rule/ParquetDereferencePushDown.java index 17f156cc0fff5..bd4d89c4fb034 100644 --- a/presto-parquet/src/main/java/com/facebook/presto/parquet/rule/ParquetDereferencePushDown.java +++ b/presto-parquet/src/main/java/com/facebook/presto/parquet/rule/ParquetDereferencePushDown.java @@ -371,7 +371,8 @@ public PlanNode visitProject(ProjectNode project, RewriteContext<Void> context) newAssignments, tableScan.getTableConstraints(), tableScan.getCurrentConstraint(), - tableScan.getEnforcedConstraint()); + tableScan.getEnforcedConstraint(), + tableScan.getCteMaterializationInfo()); Assignments.Builder newProjectAssignmentBuilder = Assignments.builder(); for (Map.Entry<VariableReferenceExpression, RowExpression> entry : project.getAssignments().entrySet()) { diff --git a/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/PinotPlanOptimizer.java b/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/PinotPlanOptimizer.java index 46911541b4eee..0e437490b26e2 100644 --- a/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/PinotPlanOptimizer.java +++ b/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/PinotPlanOptimizer.java @@ -184,7 +184,8 @@ private Optional<PlanNode> tryCreatingNewScanNode(PlanNode plan, TableScanNode t ImmutableList.copyOf(assignments.keySet()), assignments.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, (e) -> (ColumnHandle) (e.getValue()))), tableScanNode.getCurrentConstraint(), - tableScanNode.getEnforcedConstraint())); + tableScanNode.getEnforcedConstraint(), + tableScanNode.getCteMaterializationInfo())); } @Override diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestIterativePlanFragmenter.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestIterativePlanFragmenter.java index 379f912e8cba1..250b970b511ad 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestIterativePlanFragmenter.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestIterativePlanFragmenter.java @@ -312,7 +312,7 @@ private TableScanNode tableScan(String id, List<VariableReferenceExpression> var variables, assignments.build(), TupleDomain.all(), - TupleDomain.all()); + TupleDomain.all(), Optional.empty()); } private PlanNode project(String id, PlanNode source, VariableReferenceExpression variable, RowExpression expression) diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteMaterializationInfo.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteMaterializationInfo.java new file mode 100644 index 0000000000000..8954005e35eec --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteMaterializationInfo.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.plan; + +/* + * Contains information about the identifier (cteId) of the CTE being materialized. This information is stored tablescans and tablefinish plan nodesg + */ +public class CteMaterializationInfo +{ + private final String cteId; + + public CteMaterializationInfo(String cteId) + { + this.cteId = cteId; + } + + public String getCteId() + { + return cteId; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableFinishNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableFinishNode.java index 05f153cdb9e5b..281c6dec7c148 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableFinishNode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableFinishNode.java @@ -38,6 +38,8 @@ public final class TableFinishNode private final Optional<StatisticAggregations> statisticsAggregation; private final Optional<StatisticAggregationsDescriptor<VariableReferenceExpression>> statisticsAggregationDescriptor; + private final Optional<CteMaterializationInfo> temporaryTableInfo; + @JsonCreator public TableFinishNode( Optional<SourceLocation> sourceLocation, @@ -46,9 +48,10 @@ public TableFinishNode( @JsonProperty("target") Optional<TableWriterNode.WriterTarget> target, @JsonProperty("rowCountVariable") VariableReferenceExpression rowCountVariable, @JsonProperty("statisticsAggregation") Optional<StatisticAggregations> statisticsAggregation, - @JsonProperty("statisticsAggregationDescriptor") Optional<StatisticAggregationsDescriptor<VariableReferenceExpression>> statisticsAggregationDescriptor) + @JsonProperty("statisticsAggregationDescriptor") Optional<StatisticAggregationsDescriptor<VariableReferenceExpression>> statisticsAggregationDescriptor, + @JsonProperty("cteMaterializationInfo") Optional<CteMaterializationInfo> temporaryTableInfo) { - this(sourceLocation, id, Optional.empty(), source, target, rowCountVariable, statisticsAggregation, statisticsAggregationDescriptor); + this(sourceLocation, id, Optional.empty(), source, target, rowCountVariable, statisticsAggregation, statisticsAggregationDescriptor, temporaryTableInfo); } public TableFinishNode( @@ -59,11 +62,13 @@ public TableFinishNode( Optional<TableWriterNode.WriterTarget> target, VariableReferenceExpression rowCountVariable, Optional<StatisticAggregations> statisticsAggregation, - Optional<StatisticAggregationsDescriptor<VariableReferenceExpression>> statisticsAggregationDescriptor) + Optional<StatisticAggregationsDescriptor<VariableReferenceExpression>> statisticsAggregationDescriptor, + Optional<CteMaterializationInfo> temporaryTableInfo) { super(sourceLocation, id, statsEquivalentPlanNode); checkArgument(target != null || source instanceof TableWriterNode); + this.temporaryTableInfo = temporaryTableInfo; this.source = requireNonNull(source, "source is null"); this.target = requireNonNull(target, "target is null"); this.rowCountVariable = requireNonNull(rowCountVariable, "rowCountVariable is null"); @@ -120,6 +125,11 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context) return visitor.visitTableFinish(this, context); } + public Optional<CteMaterializationInfo> getCteMaterializationInfo() + { + return temporaryTableInfo; + } + @Override public PlanNode replaceChildren(List<PlanNode> newChildren) { @@ -132,7 +142,8 @@ public PlanNode replaceChildren(List<PlanNode> newChildren) target, rowCountVariable, statisticsAggregation, - statisticsAggregationDescriptor); + statisticsAggregationDescriptor, + temporaryTableInfo); } @Override @@ -146,6 +157,7 @@ public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalent target, rowCountVariable, statisticsAggregation, - statisticsAggregationDescriptor); + statisticsAggregationDescriptor, + temporaryTableInfo); } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableScanNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableScanNode.java index 17a739f855b14..ef6567103f544 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableScanNode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableScanNode.java @@ -50,6 +50,8 @@ public final class TableScanNode private final TupleDomain<ColumnHandle> enforcedConstraint; private final List<TableConstraint<ColumnHandle>> tableConstraints; + private final Optional<CteMaterializationInfo> cteMaterializationInfo; + /** * This constructor is for JSON deserialization only. Do not use! */ @@ -69,6 +71,7 @@ public TableScanNode( this.currentConstraint = null; this.enforcedConstraint = null; this.tableConstraints = emptyList(); + this.cteMaterializationInfo = Optional.empty(); } public TableScanNode( @@ -78,9 +81,9 @@ public TableScanNode( List<VariableReferenceExpression> outputVariables, Map<VariableReferenceExpression, ColumnHandle> assignments, TupleDomain<ColumnHandle> currentConstraint, - TupleDomain<ColumnHandle> enforcedConstraint) + TupleDomain<ColumnHandle> enforcedConstraint, Optional<CteMaterializationInfo> cteMaterializationInfo) { - this (sourceLocation, id, table, outputVariables, assignments, emptyList(), currentConstraint, enforcedConstraint); + this(sourceLocation, id, table, outputVariables, assignments, emptyList(), currentConstraint, enforcedConstraint, cteMaterializationInfo); } public TableScanNode( @@ -91,9 +94,9 @@ public TableScanNode( Map<VariableReferenceExpression, ColumnHandle> assignments, List<TableConstraint<ColumnHandle>> tableConstraints, TupleDomain<ColumnHandle> currentConstraint, - TupleDomain<ColumnHandle> enforcedConstraint) + TupleDomain<ColumnHandle> enforcedConstraint, Optional<CteMaterializationInfo> cteMaterializationInfo) { - this (sourceLocation, id, Optional.empty(), table, outputVariables, assignments, tableConstraints, currentConstraint, enforcedConstraint); + this(sourceLocation, id, Optional.empty(), table, outputVariables, assignments, tableConstraints, currentConstraint, enforcedConstraint, cteMaterializationInfo); } public TableScanNode( @@ -105,12 +108,14 @@ public TableScanNode( Map<VariableReferenceExpression, ColumnHandle> assignments, List<TableConstraint<ColumnHandle>> tableConstraints, TupleDomain<ColumnHandle> currentConstraint, - TupleDomain<ColumnHandle> enforcedConstraint) + TupleDomain<ColumnHandle> enforcedConstraint, + Optional<CteMaterializationInfo> cteMaterializationInfo) { super(sourceLocation, id, statsEquivalentPlanNode); this.table = requireNonNull(table, "table is null"); this.outputVariables = unmodifiableList(requireNonNull(outputVariables, "outputVariables is null")); this.assignments = unmodifiableMap(new HashMap<>(requireNonNull(assignments, "assignments is null"))); + this.cteMaterializationInfo = requireNonNull(cteMaterializationInfo, "cteMaterializationInfo is null"); checkArgument(assignments.keySet().containsAll(outputVariables), "assignments does not cover all of outputs"); this.currentConstraint = requireNonNull(currentConstraint, "currentConstraint is null"); this.enforcedConstraint = requireNonNull(enforcedConstraint, "enforcedConstraint is null"); @@ -120,6 +125,11 @@ public TableScanNode( this.tableConstraints = requireNonNull(tableConstraints, "tableConstraints is null"); } + public Optional<CteMaterializationInfo> getCteMaterializationInfo() + { + return cteMaterializationInfo; + } + /** * Get the table handle provided by connector */ @@ -206,7 +216,7 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context) @Override public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode) { - return new TableScanNode(getSourceLocation(), getId(), statsEquivalentPlanNode, table, outputVariables, assignments, tableConstraints, currentConstraint, enforcedConstraint); + return new TableScanNode(getSourceLocation(), getId(), statsEquivalentPlanNode, table, outputVariables, assignments, tableConstraints, currentConstraint, enforcedConstraint, cteMaterializationInfo); } @Override