Pass temporary tableInfo to tableScan and tableWriter
jaystarshot committed Nov 21, 2024
1 parent ec60c13 commit 64a3b17
Showing 53 changed files with 387 additions and 134 deletions.
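Most of the hunks below follow a single mechanical pattern: TableScanNode construction gains a trailing Optional<TemporaryTableInfo> argument. Call sites that cannot associate a scan with a materialized CTE pass Optional.empty(), while call sites that merely copy an existing scan propagate getTemporaryTableInfo() from the old node. The SPI classes themselves are among the 53 changed files but are not rendered on this page; the sketch below assumes TemporaryTableInfo is a small holder keyed by CTE id, since only getCteId() is exercised in the visible hunks.

package com.facebook.presto.spi.plan;

import static java.util.Objects.requireNonNull;

// Assumed minimal shape of TemporaryTableInfo; the class in this commit may carry
// additional state about the materialized temporary table.
public class TemporaryTableInfo
{
    private final String cteId;

    public TemporaryTableInfo(String cteId)
    {
        this.cteId = requireNonNull(cteId, "cteId is null");
    }

    public String getCteId()
    {
        return cteId;
    }
}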
@@ -139,7 +139,7 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
oldTableScanNode.getAssignments(),
oldTableScanNode.getTableConstraints(),
oldTableScanNode.getCurrentConstraint(),
oldTableScanNode.getEnforcedConstraint());
oldTableScanNode.getEnforcedConstraint(), Optional.empty());

return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), newTableScanNode, node.getPredicate());
}
@@ -200,7 +200,7 @@ private Optional<PlanNode> tryCreatingNewScanNode(PlanNode plan)
ImmutableList.copyOf(assignments.keySet()),
assignments.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, (e) -> (ColumnHandle) (e.getValue()))),
tableScanNode.getCurrentConstraint(),
tableScanNode.getEnforcedConstraint()));
tableScanNode.getEnforcedConstraint(), Optional.empty()));
}

@Override
@@ -288,7 +288,7 @@ public PlanNode visitFilter(FilterNode node, Void context)
oldTableScanNode.getOutputVariables(),
oldTableScanNode.getAssignments(),
oldTableScanNode.getCurrentConstraint(),
oldTableScanNode.getEnforcedConstraint());
oldTableScanNode.getEnforcedConstraint(), Optional.empty());

return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), newTableScanNode, node.getPredicate());
}
@@ -173,7 +173,7 @@ private Optional<PlanNode> tryCreatingNewScanNode(PlanNode plan)
assignments.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, (e) -> (ColumnHandle) (e.getValue()))),
tableScanNode.getTableConstraints(),
tableScanNode.getCurrentConstraint(),
tableScanNode.getEnforcedConstraint()));
tableScanNode.getEnforcedConstraint(), Optional.empty()));
}

@Override
@@ -147,7 +147,7 @@ protected TableScanNode tableScan(PlanBuilder planBuilder, DruidTableHandle conn
variables,
assignments.build(),
TupleDomain.all(),
TupleDomain.all());
TupleDomain.all(), Optional.empty());
}

protected FilterNode filter(PlanBuilder planBuilder, PlanNode source, RowExpression predicate)
@@ -307,7 +307,7 @@ private static TableScanNode getTableScanNode(
tableScan.getAssignments(),
tableScan.getTableConstraints(),
pushdownFilterResult.getLayout().getPredicate(),
TupleDomain.all());
TupleDomain.all(), tableScan.getTemporaryTableInfo());
}

private static ExtractionResult intersectExtractionResult(
@@ -71,7 +71,7 @@ public PlanNode visitTableScan(TableScanNode tableScan, RewriteContext<Void> con
tableScan.getAssignments(),
tableScan.getTableConstraints(),
tableScan.getCurrentConstraint(),
tableScan.getEnforcedConstraint());
tableScan.getEnforcedConstraint(), tableScan.getTemporaryTableInfo());
}
}
}
@@ -297,7 +297,7 @@ private Optional<PlanNode> tryPartialAggregationPushdown(PlanNode plan)
ImmutableMap.copyOf(assignments),
oldTableScanNode.getTableConstraints(),
oldTableScanNode.getCurrentConstraint(),
oldTableScanNode.getEnforcedConstraint()));
oldTableScanNode.getEnforcedConstraint(), oldTableScanNode.getTemporaryTableInfo()));
}

@Override
@@ -107,7 +107,7 @@ public void testPersistentCteWithTimeStampWithTimeZoneType()
" (CAST('2023-12-31 23:59:59.999 UTC' AS TIMESTAMP WITH TIME ZONE))" +
" ) AS t(ts)" +
")" +
"SELECT ts FROM cte";
"SELECT * FROM cte JOIN cte ON true";
QueryRunner queryRunner = getQueryRunner();
verifyResults(queryRunner, testQuery, ImmutableList.of(generateMaterializedCTEInformation("cte", 1, false, true)));
}
@@ -465,6 +465,7 @@ public void testPersistentCteWithVarbinary()
QueryRunner queryRunner = getQueryRunner();
verifyResults(queryRunner, testQuery, ImmutableList.of(generateMaterializedCTEInformation("dataset", 1, false, true)));
}

@Test
public void testComplexRefinedCtesOutsideScope()
{
@@ -706,7 +707,7 @@ public void testComplexChainOfDependentAndNestedPersistentCtes()
generateMaterializedCTEInformation("cte6", 1, false, true)));
}

@Test
@Test(enabled = false)
public void testComplexQuery1()
{
String testQuery = "WITH customer_nation AS (" +
@@ -341,7 +341,7 @@ private TableScanNode createDeletesTableScan(ImmutableMap<VariableReferenceExpre
outputs,
deleteColumnAssignments,
TupleDomain.all(),
TupleDomain.all());
TupleDomain.all(), Optional.empty());
}

/**
@@ -382,7 +382,8 @@ private TableScanNode createNewRoot(TableScanNode node, IcebergTableHandle icebe
assignmentsBuilder.build(),
node.getTableConstraints(),
node.getCurrentConstraint(),
node.getEnforcedConstraint());
node.getEnforcedConstraint(),
node.getTemporaryTableInfo());
}

/**
@@ -238,7 +238,7 @@ public PlanNode visitFilter(FilterNode filter, RewriteContext<Void> context)
.intersect(tableScan.getCurrentConstraint()),
predicateNotChangedBySimplification ?
identityPartitionColumnPredicate.intersect(tableScan.getEnforcedConstraint()) :
tableScan.getEnforcedConstraint());
tableScan.getEnforcedConstraint(), Optional.empty());

if (TRUE_CONSTANT.equals(remainingFilterExpression) && predicateNotChangedBySimplification) {
return newTableScan;
@@ -27,10 +27,16 @@
import com.facebook.presto.server.remotetask.HttpRemoteTask;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TemporaryTableInfo;
import com.facebook.presto.split.RemoteSplit;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.plan.TableFinishNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
@@ -60,6 +66,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.getMaxFailedTaskPercentage;
import static com.facebook.presto.failureDetector.FailureDetector.State.GONE;
@@ -112,6 +119,7 @@ public final class SqlStageExecution

@GuardedBy("this")
private final AtomicInteger nextTaskId = new AtomicInteger();

@GuardedBy("this")
private final Set<TaskId> allTasks = newConcurrentHashSet();
@GuardedBy("this")
@@ -557,7 +565,6 @@ private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, M
// stage finished while we were scheduling this task
task.abort();
}

return task;
}

@@ -594,6 +601,57 @@ private static Split createRemoteSplitFor(TaskId taskId, URI remoteSourceTaskLoc
return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(new Location(splitLocation), remoteSourceTaskId));
}

private String getCteIdFromSource(PlanNode source)
{
// Traverse the plan fragment to find the TableFinishNode that carries TemporaryTableInfo
return PlanNodeSearcher.searchFrom(source)
.where(planNode -> planNode instanceof TableFinishNode)
.findFirst()
.map(planNode -> ((TableFinishNode) planNode).getTemporaryTableInfo().orElseThrow(
() -> new IllegalStateException("TableFinishNode has no TemporaryTableInfo")))
.map(TemporaryTableInfo::getCteId)
.orElseThrow(() -> new IllegalStateException("TemporaryTableInfo has no CTE ID"));
}

public boolean isCTEWriterStage()
{
// A CTE writer stage contains a TableFinishNode that carries TemporaryTableInfo
List<PlanNode> cteWriterNodes = PlanNodeSearcher.searchFrom(planFragment.getRoot())
.where(planNode -> (planNode instanceof TableFinishNode) && ((TableFinishNode) planNode).getTemporaryTableInfo().isPresent())
.findAll();
return !cteWriterNodes.isEmpty();
}

public String getCTEWriterId()
{
// Validate that this is a CTE writer stage and return the associated CTE ID
if (!isCTEWriterStage()) {
throw new IllegalStateException("This stage is not a CTE writer stage");
}
return getCteIdFromSource(planFragment.getRoot());
}

public boolean requiresMaterializedCTE()
{
// Search for TableScanNodes and check if they reference TemporaryTableInfo
return PlanNodeSearcher.searchFrom(planFragment.getRoot())
.where(planNode -> planNode instanceof TableScanNode)
.findAll().stream()
.anyMatch(planNode -> ((TableScanNode) planNode).getTemporaryTableInfo().isPresent());
}

public List<String> getRequiredCTEList()
{
// Collect the CTE IDs of all TableScanNodes that carry TemporaryTableInfo;
// scans of regular tables in the same fragment are simply skipped
return PlanNodeSearcher.searchFrom(planFragment.getRoot())
.where(planNode -> planNode instanceof TableScanNode)
.findAll().stream()
.map(planNode -> ((TableScanNode) planNode).getTemporaryTableInfo())
.filter(Optional::isPresent)
.map(Optional::get)
.map(TemporaryTableInfo::getCteId)
.collect(Collectors.toList());
}

private void updateTaskStatus(TaskId taskId, TaskStatus taskStatus)
{
StageExecutionState stageExecutionState = getState();
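The helpers added above only expose whether this stage writes or reads a materialized CTE; the wiring that actually releases waiting consumers lives in scheduler code outside the hunks shown on this page. A plausible sketch of that wiring, assuming the query scheduler attaches a state-change listener to each CTE writer stage (registerCteCompletion is a hypothetical name, not taken from this commit):

// Hypothetical coordinator-side wiring: when a CTE writer stage finishes, tell the
// tracker so consumer stages blocked on WAITING_FOR_CTE_MATERIALIZATION can reschedule.
private static void registerCteCompletion(SqlStageExecution stage, CTEMaterializationTracker cteMaterializationTracker)
{
    if (!stage.isCTEWriterStage()) {
        return;
    }
    String cteId = stage.getCTEWriterId();
    stage.addStateChangeListener(newState -> {
        if (newState == StageExecutionState.FINISHED) {
            cteMaterializationTracker.markCTEAsMaterialized(cteId);
        }
    });
}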
@@ -0,0 +1,60 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution.scheduler;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class CTEMaterializationTracker
{
private final Map<String, SettableFuture<Void>> materializationFutures = new ConcurrentHashMap<>();

private final Map<String, Boolean> materializedCtes = new ConcurrentHashMap<>();

public ListenableFuture<Void> getFutureForCTE(String cteName)
{
// compute() keeps the check-and-replace atomic across concurrent consumer stages:
// reuse a live future, but replace one that was cancelled by an aborted consumer
return materializationFutures.compute(cteName, (name, existing) -> {
if (existing != null && !existing.isCancelled()) {
return existing;
}
return SettableFuture.create();
});
}

public void markCTEAsMaterialized(String cteName)
{
materializedCtes.put(cteName, true);
SettableFuture<Void> future = materializationFutures.get(cteName);
if (future != null && !future.isCancelled()) {
future.set(null); // Notify all listeners
}
}

public void markAllCTEsMaterialized()
{
materializationFutures.forEach((k, v) -> {
markCTEAsMaterialized(k);
});
}

public boolean hasBeenMaterialized(String cteName)
{
return materializedCtes.containsKey(cteName);
}
}

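A standalone illustration of the tracker's contract (the example class and its main method are illustrative only, not part of this commit): a consumer that asks for a CTE before it is materialized receives an incomplete future, and markCTEAsMaterialized completes that same future, which is what lets a blocked stage scheduler retry.

package com.facebook.presto.execution.scheduler;

import com.google.common.util.concurrent.ListenableFuture;

public class CTEMaterializationTrackerExample
{
    public static void main(String[] args)
    {
        CTEMaterializationTracker tracker = new CTEMaterializationTracker();

        // Consumer side: the stage scheduler adds this future to its blocked list
        ListenableFuture<Void> future = tracker.getFutureForCTE("cte_1");
        System.out.println(future.isDone());                      // false: not materialized yet
        System.out.println(tracker.hasBeenMaterialized("cte_1")); // false

        // Producer side: the CTE writer stage has finished
        tracker.markCTEAsMaterialized("cte_1");
        System.out.println(future.isDone());                      // true: blocked schedulers can retry
        System.out.println(tracker.hasBeenMaterialized("cte_1")); // true
    }
}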
@@ -74,6 +74,8 @@ public class FixedSourcePartitionedScheduler

private final Queue<Integer> tasksToRecover = new ConcurrentLinkedQueue<>();

private final CTEMaterializationTracker cteMaterializationTracker;

@GuardedBy("this")
private boolean closed;

@@ -87,13 +89,15 @@ public FixedSourcePartitionedScheduler(
int splitBatchSize,
OptionalInt concurrentLifespansPerTask,
NodeSelector nodeSelector,
List<ConnectorPartitionHandle> partitionHandles)
List<ConnectorPartitionHandle> partitionHandles,
CTEMaterializationTracker cteMaterializationTracker)
{
requireNonNull(stage, "stage is null");
requireNonNull(splitSources, "splitSources is null");
requireNonNull(bucketNodeMap, "bucketNodeMap is null");
checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty");
requireNonNull(partitionHandles, "partitionHandles is null");
this.cteMaterializationTracker = requireNonNull(cteMaterializationTracker, "cteMaterializationTracker is null");

this.stage = stage;
this.nodes = ImmutableList.copyOf(nodes);
@@ -179,10 +183,35 @@ public ScheduleResult schedule()
{
// schedule a task on every node in the distribution
List<RemoteTask> newTasks = ImmutableList.of();
List<ListenableFuture<?>> blocked = new ArrayList<>();

// CTE Materialization Check
if (stage.requiresMaterializedCTE()) {
List<String> requiredCTEIds = stage.getRequiredCTEList(); // CTE ids that this fragment's table scans depend on
for (String cteId : requiredCTEIds) {
if (!cteMaterializationTracker.hasBeenMaterialized(cteId)) {
// Add CTE materialization future to the blocked list
ListenableFuture<Void> materializationFuture = cteMaterializationTracker.getFutureForCTE(cteId);
blocked.add(materializationFuture);
}
}
// If any CTE is not materialized, return a blocked ScheduleResult
if (!blocked.isEmpty()) {
return ScheduleResult.blocked(
false, // the stage is not finished; it must be rescheduled once the CTEs materialize
newTasks,
whenAnyComplete(blocked), // Wait for any CTE materialization to complete
BlockedReason.WAITING_FOR_CTE_MATERIALIZATION,
0);
}
}
blocked = new ArrayList<>();
newTasks = ImmutableList.of();

// schedule a task on every node in the distribution
if (!scheduledTasks) {
newTasks = Streams.mapWithIndex(
nodes.stream(),
(node, id) -> stage.scheduleTask(node, toIntExact(id)))
nodes.stream(), (node, id) -> stage.scheduleTask(node, toIntExact(id)))
.filter(Optional::isPresent)
.map(Optional::get)
.collect(toImmutableList());
Expand All @@ -193,7 +222,6 @@ public ScheduleResult schedule()
}

boolean allBlocked = true;
List<ListenableFuture<?>> blocked = new ArrayList<>();
BlockedReason blockedReason = BlockedReason.NO_ACTIVE_DRIVER_GROUP;

if (groupedLifespanScheduler.isPresent()) {
@@ -260,7 +288,7 @@ public ScheduleResult schedule()
}
}

if (allBlocked) {
if (allBlocked && !blocked.isEmpty()) {
return ScheduleResult.blocked(sourceSchedulers.isEmpty(), newTasks, whenAnyComplete(blocked), blockedReason, splitsScheduled);
}
else {
@@ -277,15 +305,17 @@ public void recover(TaskId taskId)
public synchronized void close()
{
closed = true;
for (SourceScheduler sourceScheduler : sourceSchedulers) {
try {
sourceScheduler.close();
}
catch (Throwable t) {
log.warn(t, "Error closing split source");
if (scheduledTasks) {
for (SourceScheduler sourceScheduler : sourceSchedulers) {
try {
sourceScheduler.close();
}
catch (Throwable t) {
log.warn(t, "Error closing split source");
}
}
sourceSchedulers.clear();
}
sourceSchedulers.clear();
}

public static class BucketedSplitPlacementPolicy

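The new block at the top of schedule() can be read as a pure function from the stage's required CTE ids and the tracker's state to a list of futures to block on; only when that list is empty does split scheduling proceed. Because the blocked result wraps whenAnyComplete(blocked), the scheduler re-enters schedule() as soon as any one CTE materializes and recomputes the remaining blockers. A minimal sketch of that decision in isolation (collectCteBlockers is a hypothetical helper, not part of this commit; it uses the same Guava and java.util types imported above):

// Returns one future per required CTE that has not been materialized yet; an empty
// result means the stage is free to go ahead and schedule splits.
private static List<ListenableFuture<?>> collectCteBlockers(List<String> requiredCteIds, CTEMaterializationTracker tracker)
{
    List<ListenableFuture<?>> blocked = new ArrayList<>();
    for (String cteId : requiredCteIds) {
        if (!tracker.hasBeenMaterialized(cteId)) {
            blocked.add(tracker.getFutureForCTE(cteId));
        }
    }
    return blocked;
}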