Skip to content

Commit

Permalink
Scheduling changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jaystarshot committed Nov 21, 2024
1 parent 8303013 commit 5a5165e
Show file tree
Hide file tree
Showing 16 changed files with 249 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ public void testPersistentCteWithVarbinary()
QueryRunner queryRunner = getQueryRunner();
verifyResults(queryRunner, testQuery, ImmutableList.of(generateMaterializedCTEInformation("dataset", 1, false, true)));
}

@Test
public void testComplexRefinedCtesOutsideScope()
{
Expand Down Expand Up @@ -706,7 +707,7 @@ public void testComplexChainOfDependentAndNestedPersistentCtes()
generateMaterializedCTEInformation("cte6", 1, false, true)));
}

@Test
@Test(enabled = false)
public void testComplexQuery1()
{
String testQuery = "WITH customer_nation AS (" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@
import com.facebook.presto.server.remotetask.HttpRemoteTask;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TemporaryTableInfo;
import com.facebook.presto.split.RemoteSplit;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.plan.TableFinishNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
Expand Down Expand Up @@ -60,6 +66,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.getMaxFailedTaskPercentage;
import static com.facebook.presto.failureDetector.FailureDetector.State.GONE;
Expand Down Expand Up @@ -110,8 +117,6 @@ public final class SqlStageExecution

private final Map<InternalNode, Set<RemoteTask>> tasks = new ConcurrentHashMap<>();

private final Map<TaskId, RemoteTask> taskMap = new ConcurrentHashMap<>();

@GuardedBy("this")
private final AtomicInteger nextTaskId = new AtomicInteger();

Expand Down Expand Up @@ -515,9 +520,6 @@ public synchronized Set<RemoteTask> scheduleSplits(InternalNode node, Multimap<P

private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, Multimap<PlanNodeId, Split> sourceSplits)
{
if (taskMap.containsKey(taskId)) {
return taskMap.get(taskId);
}
checkArgument(!allTasks.contains(taskId), "A task with id %s already exists", taskId);

ImmutableMultimap.Builder<PlanNodeId, Split> initialSplits = ImmutableMultimap.builder();
Expand Down Expand Up @@ -563,7 +565,6 @@ private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, M
// stage finished while we were scheduling this task
task.abort();
}
taskMap.put(taskId, task);
return task;
}

Expand Down Expand Up @@ -600,6 +601,57 @@ private static Split createRemoteSplitFor(TaskId taskId, URI remoteSourceTaskLoc
return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(new Location(splitLocation), remoteSourceTaskId));
}

private String getCteIdFromSource(PlanNode source)
{
// Traverse the plan node tree to find a TableWriterNode with TemporaryTableInfo
return PlanNodeSearcher.searchFrom(source)
.where(planNode -> planNode instanceof TableFinishNode)
.findFirst()
.map(planNode -> ((TableFinishNode) planNode).getTemporaryTableInfo().orElseThrow(
() -> new IllegalStateException("TableFinishNode has no TemporaryTableInfo")))
.map(TemporaryTableInfo::getCteId)
.orElseThrow(() -> new IllegalStateException("TemporaryTableInfo has no CTE ID"));
}

public boolean isCTEWriterStage()
{
// Use PlanChecker or traversal utility to identify if the plan involves CTE writing
List<TableWriterNode> writerNodes = PlanNodeSearcher.searchFrom(planFragment.getRoot())
.where(planNode -> (planNode instanceof TableFinishNode) && ((TableFinishNode) planNode).getTemporaryTableInfo().isPresent())
.findAll();
return writerNodes.size() > 0;
}

public String getCTEWriterId()
{
// Validate that this is a CTE writer stage and return the associated CTE ID
if (!isCTEWriterStage()) {
throw new IllegalStateException("This stage is not a CTE writer stage");
}
return getCteIdFromSource(planFragment.getRoot());
}

public boolean requiresMaterializedCTE()
{
// Search for TableScanNodes and check if they reference TemporaryTableInfo
return PlanNodeSearcher.searchFrom(planFragment.getRoot())
.where(planNode -> planNode instanceof TableScanNode)
.findAll().stream()
.anyMatch(planNode -> ((TableScanNode) planNode).getTemporaryTableInfo().isPresent());
}

public List<String> getRequiredCTEList()
{
// Collect all CTE IDs referenced by TableScanNodes with TemporaryTableInfo
return PlanNodeSearcher.searchFrom(planFragment.getRoot())
.where(planNode -> planNode instanceof TableScanNode)
.findAll().stream()
.map(planNode -> ((TableScanNode) planNode).getTemporaryTableInfo()
.orElseThrow(() -> new IllegalStateException("TableScanNode has no TemporaryTableInfo")))
.map(TemporaryTableInfo::getCteId)
.collect(Collectors.toList());
}

private void updateTaskStatus(TaskId taskId, TaskStatus taskStatus)
{
StageExecutionState stageExecutionState = getState();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution.scheduler;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class CTEMaterializationTracker
{
private final Map<String, SettableFuture<Void>> materializationFutures = new ConcurrentHashMap<>();

private final Map<String, Boolean> materializedCtes = new ConcurrentHashMap<>();

public ListenableFuture<Void> getFutureForCTE(String cteName)
{
if (materializationFutures.containsKey(cteName)) {
if (!materializationFutures.get(cteName).isCancelled()) {
return materializationFutures.get(cteName);
}
}
materializationFutures.put(cteName, SettableFuture.create());
return materializationFutures.get(cteName);
}

public void markCTEAsMaterialized(String cteName)
{
materializedCtes.put(cteName, true);
SettableFuture<Void> future = materializationFutures.get(cteName);
if (!future.isCancelled() && future != null) {
future.set(null); // Notify all listeners
}
}

public void markAllCTEsMaterialized()
{
materializationFutures.forEach((k, v) -> {
markCTEAsMaterialized(k);
});
}

public boolean hasBeenMaterialized(String cteName)
{
return materializedCtes.containsKey(cteName);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public class FixedSourcePartitionedScheduler

private final Queue<Integer> tasksToRecover = new ConcurrentLinkedQueue<>();

private final CTEMaterializationTracker cteMaterializationTracker;

@GuardedBy("this")
private boolean closed;

Expand All @@ -87,13 +89,15 @@ public FixedSourcePartitionedScheduler(
int splitBatchSize,
OptionalInt concurrentLifespansPerTask,
NodeSelector nodeSelector,
List<ConnectorPartitionHandle> partitionHandles)
List<ConnectorPartitionHandle> partitionHandles,
CTEMaterializationTracker cteMaterializationTracker)
{
requireNonNull(stage, "stage is null");
requireNonNull(splitSources, "splitSources is null");
requireNonNull(bucketNodeMap, "bucketNodeMap is null");
checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty");
requireNonNull(partitionHandles, "partitionHandles is null");
this.cteMaterializationTracker = cteMaterializationTracker;

this.stage = stage;
this.nodes = ImmutableList.copyOf(nodes);
Expand Down Expand Up @@ -179,10 +183,35 @@ public ScheduleResult schedule()
{
// schedule a task on every node in the distribution
List<RemoteTask> newTasks = ImmutableList.of();
List<ListenableFuture<?>> blocked = new ArrayList<>();

// CTE Materialization Check
if (stage.requiresMaterializedCTE()) {
List<String> requiredCTEIds = stage.getRequiredCTEList(); // Ensure this method exists and returns a list of required CTE IDs as strings
for (String cteId : requiredCTEIds) {
if (!cteMaterializationTracker.hasBeenMaterialized(cteId)) {
// Add CTE materialization future to the blocked list
ListenableFuture<Void> materializationFuture = cteMaterializationTracker.getFutureForCTE(cteId);
blocked.add(materializationFuture);
}
}
// If any CTE is not materialized, return a blocked ScheduleResult
if (!blocked.isEmpty()) {
return ScheduleResult.blocked(
false, // true if all required CTEs are blocked
newTasks,
whenAnyComplete(blocked), // Wait for any CTE materialization to complete
BlockedReason.WAITING_FOR_CTE_MATERIALIZATION,
0);
}
}
blocked = new ArrayList<>();
newTasks = ImmutableList.of();

// schedule a task on every node in the distribution
if (!scheduledTasks) {
newTasks = Streams.mapWithIndex(
nodes.stream(),
(node, id) -> stage.scheduleTask(node, toIntExact(id)))
nodes.stream(), (node, id) -> stage.scheduleTask(node, toIntExact(id)))
.filter(Optional::isPresent)
.map(Optional::get)
.collect(toImmutableList());
Expand All @@ -193,7 +222,6 @@ public ScheduleResult schedule()
}

boolean allBlocked = true;
List<ListenableFuture<?>> blocked = new ArrayList<>();
BlockedReason blockedReason = BlockedReason.NO_ACTIVE_DRIVER_GROUP;

if (groupedLifespanScheduler.isPresent()) {
Expand Down Expand Up @@ -260,7 +288,7 @@ public ScheduleResult schedule()
}
}

if (allBlocked) {
if (allBlocked && !blocked.isEmpty()) {
return ScheduleResult.blocked(sourceSchedulers.isEmpty(), newTasks, whenAnyComplete(blocked), blockedReason, splitsScheduled);
}
else {
Expand All @@ -277,15 +305,17 @@ public void recover(TaskId taskId)
public synchronized void close()
{
closed = true;
for (SourceScheduler sourceScheduler : sourceSchedulers) {
try {
sourceScheduler.close();
}
catch (Throwable t) {
log.warn(t, "Error closing split source");
if (scheduledTasks) {
for (SourceScheduler sourceScheduler : sourceSchedulers) {
try {
sourceScheduler.close();
}
catch (Throwable t) {
log.warn(t, "Error closing split source");
}
}
sourceSchedulers.clear();
}
sourceSchedulers.clear();
}

public static class BucketedSplitPlacementPolicy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,19 @@ public enum BlockedReason
* grouped execution where there are multiple lifespans per task).
*/
MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE,

WAITING_FOR_CTE_MATERIALIZATION,
/**/;

public BlockedReason combineWith(BlockedReason other)
{
switch (this) {
case WRITER_SCALING:
throw new IllegalArgumentException("cannot be combined");
case WAITING_FOR_CTE_MATERIALIZATION:
return other;
case NO_ACTIVE_DRIVER_GROUP:
return other;
case WRITER_SCALING:
throw new IllegalArgumentException("cannot be combined");
case SPLIT_QUEUES_FULL:
return other == SPLIT_QUEUES_FULL || other == NO_ACTIVE_DRIVER_GROUP ? SPLIT_QUEUES_FULL : MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE;
case WAITING_FOR_SOURCE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ public SectionExecution createSectionExecutions(
boolean summarizeTaskInfo,
RemoteTaskFactory remoteTaskFactory,
SplitSourceFactory splitSourceFactory,
int attemptId)
int attemptId,
CTEMaterializationTracker cteMaterializationTracker)
{
// Only fetch a distribution once per section to ensure all stages see the same machine assignments
Map<PartitioningHandle, NodePartitionMap> partitioningCache = new HashMap<>();
Expand All @@ -184,7 +185,8 @@ public SectionExecution createSectionExecutions(
summarizeTaskInfo,
remoteTaskFactory,
splitSourceFactory,
attemptId);
attemptId,
cteMaterializationTracker);
StageExecutionAndScheduler rootStage = getLast(sectionStages);
rootStage.getStageExecution().setOutputBuffers(outputBuffers);
return new SectionExecution(rootStage, sectionStages);
Expand All @@ -203,7 +205,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions(
boolean summarizeTaskInfo,
RemoteTaskFactory remoteTaskFactory,
SplitSourceFactory splitSourceFactory,
int attemptId)
int attemptId,
CTEMaterializationTracker cteMaterializationTracker)
{
ImmutableList.Builder<StageExecutionAndScheduler> stageExecutionAndSchedulers = ImmutableList.builder();

Expand Down Expand Up @@ -238,7 +241,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions(
summarizeTaskInfo,
remoteTaskFactory,
splitSourceFactory,
attemptId);
attemptId,
cteMaterializationTracker);
stageExecutionAndSchedulers.addAll(subTree);
childStagesBuilder.add(getLast(subTree).getStageExecution());
}
Expand All @@ -260,7 +264,8 @@ private List<StageExecutionAndScheduler> createStreamingLinkedStageExecutions(
stageExecution,
partitioningHandle,
tableWriteInfo,
childStageExecutions);
childStageExecutions,
cteMaterializationTracker);
stageExecutionAndSchedulers.add(new StageExecutionAndScheduler(
stageExecution,
stageLinkage,
Expand All @@ -279,7 +284,8 @@ private StageScheduler createStageScheduler(
SqlStageExecution stageExecution,
PartitioningHandle partitioningHandle,
TableWriteInfo tableWriteInfo,
Set<SqlStageExecution> childStageExecutions)
Set<SqlStageExecution> childStageExecutions,
CTEMaterializationTracker cteMaterializationTracker)
{
Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(plan.getFragment(), session, tableWriteInfo);
int maxTasksPerStage = getMaxTasksPerStage(session);
Expand Down Expand Up @@ -383,7 +389,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
splitBatchSize,
getConcurrentLifespansPerNode(session),
nodeScheduler.createNodeSelector(session, connectorId, nodePredicate),
connectorPartitionHandles);
connectorPartitionHandles,
cteMaterializationTracker);
if (plan.getFragment().getStageExecutionDescriptor().isRecoverableGroupedExecution()) {
stageExecution.registerStageTaskRecoveryCallback(taskId -> {
checkArgument(taskId.getStageExecutionId().getStageId().equals(stageId), "The task did not execute this stage");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public synchronized void rewindLifespan(Lifespan lifespan, ConnectorPartitionHan
@Override
public synchronized ScheduleResult schedule()
{
dropListenersFromWhenFinishedOrNewLifespansAdded();
// dropListenersFromWhenFinishedOrNewLifespansAdded();

int overallSplitAssignmentCount = 0;
ImmutableSet.Builder<RemoteTask> overallNewTasks = ImmutableSet.builder();
Expand Down
Loading

0 comments on commit 5a5165e

Please sign in to comment.