Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] prefer least allocated model when a new node is added to the cluster #77756

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;

import java.util.Collections;
import java.util.List;
import java.util.Comparator;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;

public class TrainedModelAllocationClusterService implements ClusterStateListener {
Expand Down Expand Up @@ -245,8 +246,7 @@ ClusterState createModelAllocation(ClusterState currentState, StartTrainedModelD
Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
Map<String, String> nodeToReason = new TreeMap<>();
for (DiscoveryNode node : currentState.getNodes().getAllNodes()) {
if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
&& shuttingDownNodes.contains(node.getId()) == false) {
if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node) && shuttingDownNodes.contains(node.getId()) == false) {
Optional<String> maybeError = nodeHasCapacity(currentState, params, node);
if (maybeError.isPresent()) {
nodeToReason.put(node.getName(), maybeError.get());
Expand Down Expand Up @@ -289,16 +289,8 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
logger.trace(
() -> new ParameterizedMessage("[{}] [{}] current metadata before update {}", modelId, nodeId, Strings.toString(metadata))
);
Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is old, unused code, decided to clean it up. Unrelated to the change in this PR.

List<DiscoveryNode> allocatableNodes = currentState.nodes()
.getAllNodes()
.stream()
.filter(
d -> StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(d) && shuttingDownNodes.contains(d.getId()) == false
)
.collect(Collectors.toList());
final TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId);
final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
// If state is stopped, this indicates the node process is closed, remove the node from the allocation
if (request.getRoutingState().getState().equals(RoutingState.STOPPED)) {
if (existingAllocation == null || existingAllocation.isRoutedToNode(nodeId) == false) {
Expand All @@ -313,20 +305,20 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
}
// If we are stopping, don't update anything
if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
logger.debug(() -> new ParameterizedMessage(
"[{}] requested update from node [{}] to update route state to [{}]",
modelId,
nodeId,
request.getRoutingState()
));
logger.debug(
() -> new ParameterizedMessage(
"[{}] requested update from node [{}] to update route state to [{}]",
modelId,
nodeId,
request.getRoutingState()
)
);
return currentState;
}
if (existingAllocation.isRoutedToNode(nodeId) == false) {
throw new ResourceNotFoundException("allocation for model with id [{}]] is not routed to node [{}]", modelId, nodeId);
}
builder.getAllocation(modelId)
.updateExistingRoutingEntry(nodeId, request.getRoutingState())
.calculateAndSetAllocationState();
builder.getAllocation(modelId).updateExistingRoutingEntry(nodeId, request.getRoutingState()).calculateAndSetAllocationState();

return update(currentState, builder);
}
Expand All @@ -342,7 +334,7 @@ static ClusterState removeAllocation(ClusterState currentState, String modelId)
static ClusterState removeAllAllocations(ClusterState currentState) {
if (TrainedModelAllocationMetadata.fromState(currentState).modelAllocations().isEmpty()) {
return currentState;
};
}
return ClusterState.builder(currentState)
.metadata(
Metadata.builder(currentState.metadata())
Expand All @@ -356,64 +348,62 @@ ClusterState addRemoveAllocationNodes(ClusterState currentState) {
final TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata.fromState(currentState);
final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
Set<String> currentNotShuttingDownNodes = currentState.getNodes()
Map<String, DiscoveryNode> currentEligibleNodes = currentState.getNodes()
.getAllNodes()
.stream()
.map(DiscoveryNode::getId)
.filter(id -> shuttingDownNodes.contains(id) == false)
.collect(Collectors.toSet());
// TODO: make more efficient, right now this is O(nm) where n = sizeof(models) and m = sizeof(nodes)
// It could probably be O(max(n, m))
// Add nodes and keep track of currently routed nodes
// Should we indicate a partial allocation somehow if some nodes don't have space?
for (Map.Entry<String, TrainedModelAllocation> modelAllocationEntry : previousState.modelAllocations().entrySet()) {
// Don't bother adding/removing nodes if this allocation is stopping
if (modelAllocationEntry.getValue().getAllocationState().equals(AllocationState.STOPPING)) {
continue;
}
final String modelId = modelAllocationEntry.getKey();
Map<String, String> nodeToReason = new TreeMap<>();
for (DiscoveryNode node : currentState.getNodes()) {
// Only add the route if the node is NOT shutting down, this would be a weird case of the node
// just being added to the cluster and immediately shutting down...
if (shuttingDownNodes.contains(node.getId()) == false
&& StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
&& modelAllocationEntry.getValue().isRoutedToNode(node.getId()) == false) {
Optional<String> failure = nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node);
if (failure.isPresent()) {
nodeToReason.put(node.getName(), failure.get());
} else {
builder.getAllocation(modelId).addNewRoutingEntry(node.getId());
// TODO: Change when we update `mayAllocateToNode`
.filter(node -> shuttingDownNodes.contains(node.getId()) == false
&& StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node))
.collect(Collectors.toMap(DiscoveryNode::getId, Function.identity()));
// TODO: make more efficient, we iterate every entry, sorting by nodes routed (fewest to most)
previousState.modelAllocations()
.entrySet()
.stream()
.filter(entry -> entry.getValue().getAllocationState().equals(AllocationState.STOPPING) == false)
.sorted(Comparator.comparing(e -> e.getValue().getNodeRoutingTable().size()))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is obviously not optimal, but I am not 100% sure how else to do it.

One issue here (besides performance), is that we always use the counts based on the original cluster state. Meaning, as models are allocated, our iteration order isn't updated in the same pass.

This is probably acceptable.

Another option is that we maybe iterate on nodes in the outer loop? This way all allocations attempt the same node and then we move on? The only downside here is that the order of inner iteration may change all the time.

Another option is attempting this same thing with a priority queue that updates on each iteration, popping off all the least allocations and rebuilding each time, but this seems pretty expensive to me as the whole heap would be rebuilt on every pass on a new node...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A low hanging fruit here is to apply the filter of StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode at the point where we collect currentNotShuttingDownNodes. So that would become currentEligibleNodes for instance and it'd be non-shutting down, ML nodes of an appropriate version. This would significantly reduce one of the parameters here for most clusters.

Not sure I can think of something else from a first look. But if we make this optimization, then we can run a few tests to see how long it takes to run this with realistic parameter values. It might not be a problem at all.

.forEach(modelAllocationEntry -> {
final String modelId = modelAllocationEntry.getKey();
Map<String, String> nodeToReason = new TreeMap<>();
for (DiscoveryNode node : currentEligibleNodes.values()) {
if (modelAllocationEntry.getValue().isRoutedToNode(node.getId()) == false) {
Optional<String> failure = builder.isChanged() ?
// We use the builder only if we have changed, there is no point in creating a new object if we haven't changed
nodeHasCapacity(currentState, builder, modelAllocationEntry.getValue().getTaskParams(), node) :
nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node);
if (failure.isPresent()) {
nodeToReason.put(node.getName(), failure.get());
} else {
builder.getAllocation(modelId).addNewRoutingEntry(node.getId());
}
}
}
}
if (nodeToReason.isEmpty() == false) {
builder.getAllocation(modelId)
.setReason(
nodeToReason.entrySet()
.stream()
.map(
entry -> String.format(
Locale.ROOT,
"Not allocating on node [%s]. Reason: %s",
entry.getKey(),
entry.getValue()
if (nodeToReason.isEmpty() == false) {
builder.getAllocation(modelId)
.setReason(
nodeToReason.entrySet()
.stream()
.map(
entry -> String.format(
Locale.ROOT,
"Not allocating on node [%s]. Reason: %s",
entry.getKey(),
entry.getValue()
)
)
)
.collect(Collectors.joining("|"))
);
} else {
builder.getAllocation(modelId).clearReason();
}
for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) {
if (currentNotShuttingDownNodes.contains(nodeId) == false) {
builder.getAllocation(modelId).removeRoutingEntry(nodeId);
.collect(Collectors.joining("|"))
);
} else {
builder.getAllocation(modelId).clearReason();
}
}
// It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes
// Or moved from PARTIALLY_STARTED to STARTED if a node was removed
builder.getAllocation(modelId).calculateAndSetAllocationState();
}
for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) {
if (currentEligibleNodes.containsKey(nodeId) == false) {
builder.getAllocation(modelId).removeRoutingEntry(nodeId);
}
}
// It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes
// Or moved from PARTIALLY_STARTED to STARTED if a node was removed
builder.getAllocation(modelId).calculateAndSetAllocationState();
});
return update(currentState, builder);
}

Expand Down Expand Up @@ -448,8 +438,33 @@ static boolean shouldAllocateModels(final ClusterChangedEvent event) {

/**
 * Determines whether {@code node} can host the model deployment described by {@code params},
 * computing the node's load from the allocation metadata currently stored in the cluster state.
 *
 * @param state  the current cluster state used to detect the node's load
 * @param params the deployment task parameters (model id and estimated memory usage)
 * @param node   the candidate node
 * @return an empty {@link Optional} if the node has capacity, otherwise the reason it does not
 */
Optional<String> nodeHasCapacity(ClusterState state, StartTrainedModelDeploymentAction.TaskParams params, DiscoveryNode node) {
    final NodeLoad currentLoad = nodeLoadDetector.detectNodeLoad(
        state,
        true,
        node,
        Integer.MAX_VALUE,
        maxMemoryPercentage,
        useAuto
    );
    return handleNodeLoad(currentLoad, node.getId(), params);
}

/**
 * Variant of {@link #nodeHasCapacity(ClusterState, StartTrainedModelDeploymentAction.TaskParams, DiscoveryNode)}
 * that derives the node's load from the supplied allocation metadata builder — i.e. including any
 * routing changes made so far in the current pass — instead of the metadata stored in cluster state.
 *
 * @param state   the current cluster state
 * @param builder in-progress allocation metadata reflecting pending routing changes
 * @param params  the deployment task parameters (model id and estimated memory usage)
 * @param node    the candidate node
 * @return an empty {@link Optional} if the node has capacity, otherwise the reason it does not
 */
Optional<String> nodeHasCapacity(
    ClusterState state,
    TrainedModelAllocationMetadata.Builder builder,
    StartTrainedModelDeploymentAction.TaskParams params,
    DiscoveryNode node
) {
    final NodeLoad loadWithPendingChanges =
        nodeLoadDetector.detectNodeLoad(state, builder.build(), true, node, Integer.MAX_VALUE, maxMemoryPercentage, useAuto);
    return handleNodeLoad(loadWithPendingChanges, node.getId(), params);
}

Optional<String> handleNodeLoad(NodeLoad load, String nodeId, StartTrainedModelDeploymentAction.TaskParams params) {
if (Strings.isNullOrEmpty(load.getError()) == false) {
logger.warn("[{}] failed to calculate current node load with error [{}]", params.getModelId(), node.getId());
logger.warn("[{}] failed to calculate current node load with error [{}]", params.getModelId(), nodeId);
return Optional.of(load.getError());
}
if (load.getFreeMemory() < params.estimateMemoryUsageBytes()) {
Expand All @@ -464,8 +479,7 @@ Optional<String> nodeHasCapacity(ClusterState state, StartTrainedModelDeployment
load.getAssignedJobMemory(),
ByteSizeValue.ofBytes(load.getAssignedJobMemory()).toString(),
params.estimateMemoryUsageBytes(),
ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString()
}
ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString() }
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ public NodeLoad detectNodeLoad(ClusterState clusterState,
int dynamicMaxOpenJobs,
int maxMachineMemoryPercent,
boolean useAutoMachineMemoryCalculation) {
return detectNodeLoad(
clusterState,
TrainedModelAllocationMetadata.fromState(clusterState),
allNodesHaveDynamicMaxWorkers,
node,
dynamicMaxOpenJobs,
maxMachineMemoryPercent,
useAutoMachineMemoryCalculation
);
}

public NodeLoad detectNodeLoad(ClusterState clusterState,
TrainedModelAllocationMetadata allocationMetadata,
boolean allNodesHaveDynamicMaxWorkers,
DiscoveryNode node,
int dynamicMaxOpenJobs,
int maxMachineMemoryPercent,
boolean useAutoMachineMemoryCalculation) {
PersistentTasksCustomMetadata persistentTasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
Map<String, String> nodeAttributes = node.getAttributes();
List<String> errors = new ArrayList<>();
Expand Down Expand Up @@ -80,7 +98,7 @@ public NodeLoad detectNodeLoad(ClusterState clusterState,
return nodeLoad.setError(Strings.collectionToCommaDelimitedString(errors)).build();
}
updateLoadGivenTasks(nodeLoad, persistentTasks);
updateLoadGivenModelAllocations(nodeLoad, TrainedModelAllocationMetadata.fromState(clusterState));
updateLoadGivenModelAllocations(nodeLoad, allocationMetadata);
return nodeLoad.build();
}

Expand Down
Loading