[ML] prefer least allocated model when a new node is added to the cluster #77756
Changes from all commits: 472704b, 6406ebb, 9a307e6, 3bd697d
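The thrust of the change: when nodes are added, `addRemoveAllocationNodes` now visits model allocations sorted by how many nodes each model is already routed to (fewest first), so the least-allocated models get first claim on new capacity. A minimal, self-contained sketch of that ordering idea, where the `ModelAllocation` record and the sample data are simplified stand-ins for the real `TrainedModelAllocation` metadata:

import java.util.Comparator;
import java.util.List;
import java.util.Map;

public class LeastAllocatedFirst {
    // Simplified stand-in for TrainedModelAllocation: the model id plus the
    // node ids the model is currently routed to.
    record ModelAllocation(String modelId, Map<String, String> nodeRoutingTable) {}

    // Fewest routed nodes first, mirroring the Comparator used in the PR, so
    // under-allocated models get first pick of a newly added node's capacity.
    static List<ModelAllocation> allocationOrder(List<ModelAllocation> allocations) {
        return allocations.stream()
            .sorted(Comparator.comparingInt(a -> a.nodeRoutingTable().size()))
            .toList();
    }

    public static void main(String[] args) {
        List<ModelAllocation> allocations = List.of(
            new ModelAllocation("model-a", Map.of("node-1", "started", "node-2", "started")),
            new ModelAllocation("model-b", Map.of()),
            new ModelAllocation("model-c", Map.of("node-1", "started"))
        );
        // Prints model-b, model-c, model-a: least allocated first.
        allocationOrder(allocations).forEach(a -> System.out.println(a.modelId()));
    }
}

Note the ordering is computed once per pass over the cluster state; as the diff's review thread discusses, routes added during the pass do not re-sort the remaining allocations.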
TrainedModelAllocationClusterService.java

@@ -40,12 +40,13 @@
 import org.elasticsearch.xpack.ml.job.NodeLoadDetector;

 import java.util.Collections;
-import java.util.List;
+import java.util.Comparator;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 import java.util.TreeMap;
+import java.util.function.Function;
 import java.util.stream.Collectors;

 public class TrainedModelAllocationClusterService implements ClusterStateListener {
@@ -245,8 +246,7 @@ ClusterState createModelAllocation(ClusterState currentState, StartTrainedModelD
         Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
         Map<String, String> nodeToReason = new TreeMap<>();
         for (DiscoveryNode node : currentState.getNodes().getAllNodes()) {
-            if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
-                && shuttingDownNodes.contains(node.getId()) == false) {
+            if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node) && shuttingDownNodes.contains(node.getId()) == false) {
                 Optional<String> maybeError = nodeHasCapacity(currentState, params, node);
                 if (maybeError.isPresent()) {
                     nodeToReason.put(node.getName(), maybeError.get());
@@ -289,16 +289,8 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
         logger.trace(
             () -> new ParameterizedMessage("[{}] [{}] current metadata before update {}", modelId, nodeId, Strings.toString(metadata))
         );
-        Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
-        List<DiscoveryNode> allocatableNodes = currentState.nodes()
-            .getAllNodes()
-            .stream()
-            .filter(
-                d -> StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(d) && shuttingDownNodes.contains(d.getId()) == false
-            )
-            .collect(Collectors.toList());
         final TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId);
         final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
         // If state is stopped, this indicates the node process is closed, remove the node from the allocation
         if (request.getRoutingState().getState().equals(RoutingState.STOPPED)) {
             if (existingAllocation == null || existingAllocation.isRoutedToNode(nodeId) == false) {
@@ -313,20 +305,20 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
         }
         // If we are stopping, don't update anything
         if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
-            logger.debug(() -> new ParameterizedMessage(
-                "[{}] requested update from node [{}] to update route state to [{}]",
-                modelId,
-                nodeId,
-                request.getRoutingState()
-            ));
+            logger.debug(
+                () -> new ParameterizedMessage(
+                    "[{}] requested update from node [{}] to update route state to [{}]",
+                    modelId,
+                    nodeId,
+                    request.getRoutingState()
+                )
+            );
             return currentState;
         }
         if (existingAllocation.isRoutedToNode(nodeId) == false) {
             throw new ResourceNotFoundException("allocation for model with id [{}]] is not routed to node [{}]", modelId, nodeId);
         }
-        builder.getAllocation(modelId)
-            .updateExistingRoutingEntry(nodeId, request.getRoutingState())
-            .calculateAndSetAllocationState();
+        builder.getAllocation(modelId).updateExistingRoutingEntry(nodeId, request.getRoutingState()).calculateAndSetAllocationState();

         return update(currentState, builder);
     }
@@ -342,7 +334,7 @@ static ClusterState removeAllocation(ClusterState currentState, String modelId)
     static ClusterState removeAllAllocations(ClusterState currentState) {
         if (TrainedModelAllocationMetadata.fromState(currentState).modelAllocations().isEmpty()) {
             return currentState;
-        };
+        }
         return ClusterState.builder(currentState)
             .metadata(
                 Metadata.builder(currentState.metadata())
@@ -356,64 +348,62 @@ ClusterState addRemoveAllocationNodes(ClusterState currentState) {
         final TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata.fromState(currentState);
         final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
         Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
-        Set<String> currentNotShuttingDownNodes = currentState.getNodes()
+        Map<String, DiscoveryNode> currentEligibleNodes = currentState.getNodes()
             .getAllNodes()
             .stream()
-            .map(DiscoveryNode::getId)
-            .filter(id -> shuttingDownNodes.contains(id) == false)
-            .collect(Collectors.toSet());
-        // TODO: make more efficient, right now this is O(nm) where n = sizeof(models) and m = sizeof(nodes)
-        // It could probably be O(max(n, m))
-        // Add nodes and keep track of currently routed nodes
-        // Should we indicate a partial allocation somehow if some nodes don't have space?
-        for (Map.Entry<String, TrainedModelAllocation> modelAllocationEntry : previousState.modelAllocations().entrySet()) {
-            // Don't bother adding/removing nodes if this allocation is stopping
-            if (modelAllocationEntry.getValue().getAllocationState().equals(AllocationState.STOPPING)) {
-                continue;
-            }
-            final String modelId = modelAllocationEntry.getKey();
-            Map<String, String> nodeToReason = new TreeMap<>();
-            for (DiscoveryNode node : currentState.getNodes()) {
-                // Only add the route if the node is NOT shutting down, this would be a weird case of the node
-                // just being added to the cluster and immediately shutting down...
-                if (shuttingDownNodes.contains(node.getId()) == false
-                    && StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
-                    && modelAllocationEntry.getValue().isRoutedToNode(node.getId()) == false) {
-                    Optional<String> failure = nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node);
-                    if (failure.isPresent()) {
-                        nodeToReason.put(node.getName(), failure.get());
-                    } else {
-                        builder.getAllocation(modelId).addNewRoutingEntry(node.getId());
+            // TODO: Change when we update `mayAllocateToNode`
+            .filter(node -> shuttingDownNodes.contains(node.getId()) == false
+                && StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node))
+            .collect(Collectors.toMap(DiscoveryNode::getId, Function.identity()));
+        // TODO: make more efficient, we iterate every entry, sorting by nodes routed (fewest to most)
+        previousState.modelAllocations()
+            .entrySet()
+            .stream()
+            .filter(entry -> entry.getValue().getAllocationState().equals(AllocationState.STOPPING) == false)
+            .sorted(Comparator.comparing(e -> e.getValue().getNodeRoutingTable().size()))
+            .forEach(modelAllocationEntry -> {
+                final String modelId = modelAllocationEntry.getKey();
+                Map<String, String> nodeToReason = new TreeMap<>();
+                for (DiscoveryNode node : currentEligibleNodes.values()) {
+                    if (modelAllocationEntry.getValue().isRoutedToNode(node.getId()) == false) {
+                        Optional<String> failure = builder.isChanged() ?
+                            // We use the builder only if we have changed, there is no point in creating a new object if we haven't changed
+                            nodeHasCapacity(currentState, builder, modelAllocationEntry.getValue().getTaskParams(), node) :
+                            nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node);
+                        if (failure.isPresent()) {
+                            nodeToReason.put(node.getName(), failure.get());
+                        } else {
+                            builder.getAllocation(modelId).addNewRoutingEntry(node.getId());
+                        }
                     }
                 }
-            }
-            if (nodeToReason.isEmpty() == false) {
-                builder.getAllocation(modelId)
-                    .setReason(
-                        nodeToReason.entrySet()
-                            .stream()
-                            .map(
-                                entry -> String.format(
-                                    Locale.ROOT,
-                                    "Not allocating on node [%s]. Reason: %s",
-                                    entry.getKey(),
-                                    entry.getValue()
+                if (nodeToReason.isEmpty() == false) {
+                    builder.getAllocation(modelId)
+                        .setReason(
+                            nodeToReason.entrySet()
+                                .stream()
+                                .map(
+                                    entry -> String.format(
+                                        Locale.ROOT,
+                                        "Not allocating on node [%s]. Reason: %s",
+                                        entry.getKey(),
+                                        entry.getValue()
+                                    )
                                 )
-                            )
-                            .collect(Collectors.joining("|"))
-                    );
-            } else {
-                builder.getAllocation(modelId).clearReason();
-            }
-            for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) {
-                if (currentNotShuttingDownNodes.contains(nodeId) == false) {
-                    builder.getAllocation(modelId).removeRoutingEntry(nodeId);
+                                .collect(Collectors.joining("|"))
+                        );
+                } else {
+                    builder.getAllocation(modelId).clearReason();
                 }
-            }
-            // It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes
-            // Or moved from PARTIALLY_STARTED to STARTED if a node was removed
-            builder.getAllocation(modelId).calculateAndSetAllocationState();
-        }
+                for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) {
+                    if (currentEligibleNodes.containsKey(nodeId) == false) {
+                        builder.getAllocation(modelId).removeRoutingEntry(nodeId);
+                    }
+                }
+                // It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes
+                // Or moved from PARTIALLY_STARTED to STARTED if a node was removed
+                builder.getAllocation(modelId).calculateAndSetAllocationState();
+            });
        return update(currentState, builder);
    }

Review comment (on the `.sorted(...)` line): This is obviously not optimal, but I am not 100% sure how else to do it. One issue here (besides performance) is that we always use the counts based on the original cluster state. Meaning, as models are allocated, our iteration order isn't updated in the same pass. This is probably acceptable. Another option is that we maybe iterate on nodes in the outer loop? This way all allocations attempt the same node and then we move on? The only downside here is that the order of inner iteration may change all the time. Another option is attempting this same thing with a priority queue that updates on each iteration, popping off all the least allocations and rebuilding each time, but this seems pretty expensive to me as the whole heap would be rebuilt on every pass on a new node...

Reply: A low hanging fruit here is to apply the filter of […]. Not sure I can think of something else from a first look. But if we make this optimization, then we can run a few tests to see how long it takes to run this with realistic parameter values. It might not be a problem at all.
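For context on the thread above, here is a rough, hypothetical sketch of the priority-queue alternative the first comment describes (not what the PR implements; `Alloc` and `offerNewNode` are invented names): every assignment pushes an updated entry back into the heap, which is the repeated maintenance cost the comment calls "pretty expensive".

import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

// Hypothetical sketch of the alternative floated in the review thread, NOT
// the approach the PR takes: keep allocations in a heap keyed by routed-node
// count so the ordering reflects routes added earlier in the same pass.
public class PriorityQueueAlternative {
    record Alloc(String modelId, int routedNodes) {}

    static void offerNewNode(List<Alloc> allocations, String newNodeId) {
        PriorityQueue<Alloc> queue = new PriorityQueue<>(Comparator.comparingInt(Alloc::routedNodes));
        queue.addAll(allocations);
        Set<String> offered = new HashSet<>();
        while (offered.size() < allocations.size()) {
            Alloc least = queue.poll();
            if (offered.add(least.modelId()) == false) {
                continue; // this model already had its shot at the new node
            }
            // A real implementation would run a capacity check (nodeHasCapacity)
            // here before adding a routing entry.
            System.out.println("route " + least.modelId() + " -> " + newNodeId);
            // Re-insert with the updated count; this per-assignment heap update
            // is the cost the comment flags, versus one upfront sort.
            queue.add(new Alloc(least.modelId(), least.routedNodes() + 1));
        }
    }

    public static void main(String[] args) {
        offerNewNode(List.of(new Alloc("model-a", 2), new Alloc("model-b", 0)), "new-node");
    }
}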
@@ -448,8 +438,33 @@ static boolean shouldAllocateModels(final ClusterChangedEvent event) {

     Optional<String> nodeHasCapacity(ClusterState state, StartTrainedModelDeploymentAction.TaskParams params, DiscoveryNode node) {
         NodeLoad load = nodeLoadDetector.detectNodeLoad(state, true, node, Integer.MAX_VALUE, maxMemoryPercentage, useAuto);
+        return handleNodeLoad(load, node.getId(), params);
+    }
+
+    /**
+     * Gather current node capacity taking the passed allocation metadata into account instead of the one stored in cluster state.
+     */
+    Optional<String> nodeHasCapacity(
+        ClusterState state,
+        TrainedModelAllocationMetadata.Builder builder,
+        StartTrainedModelDeploymentAction.TaskParams params,
+        DiscoveryNode node
+    ) {
+        NodeLoad load = nodeLoadDetector.detectNodeLoad(
+            state,
+            builder.build(),
+            true,
+            node,
+            Integer.MAX_VALUE,
+            maxMemoryPercentage,
+            useAuto
+        );
+        return handleNodeLoad(load, node.getId(), params);
+    }
+
+    Optional<String> handleNodeLoad(NodeLoad load, String nodeId, StartTrainedModelDeploymentAction.TaskParams params) {
         if (Strings.isNullOrEmpty(load.getError()) == false) {
-            logger.warn("[{}] failed to calculate current node load with error [{}]", params.getModelId(), node.getId());
+            logger.warn("[{}] failed to calculate current node load with error [{}]", params.getModelId(), nodeId);
             return Optional.of(load.getError());
         }
         if (load.getFreeMemory() < params.estimateMemoryUsageBytes()) {
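The refactor above splits load detection from the capacity decision: both `nodeHasCapacity` overloads now funnel into a shared `handleNodeLoad`. A simplified stand-alone version of that decision logic, using a stub `NodeLoad` record whose field names are assumptions in place of the real class:

import java.util.Optional;

public class CapacityCheck {
    // Stub standing in for the real NodeLoad; field names are assumed.
    record NodeLoad(String error, long freeMemory) {}

    // Mirrors the shape of handleNodeLoad: surface load-detection errors
    // first, then reject nodes without enough free memory for the model.
    // An empty Optional means the node has capacity.
    static Optional<String> handleNodeLoad(NodeLoad load, String nodeId, long estimatedModelBytes) {
        if (load.error() != null && load.error().isEmpty() == false) {
            return Optional.of(load.error());
        }
        if (load.freeMemory() < estimatedModelBytes) {
            return Optional.of("Not enough free memory on node [" + nodeId + "]");
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        // Prints: Optional[Not enough free memory on node [node-1]]
        System.out.println(handleNodeLoad(new NodeLoad(null, 1024L), "node-1", 4096L));
    }
}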
@@ -464,8 +479,7 @@ Optional<String> nodeHasCapacity(ClusterState state, StartTrainedModelDeployment
                         load.getAssignedJobMemory(),
                         ByteSizeValue.ofBytes(load.getAssignedJobMemory()).toString(),
                         params.estimateMemoryUsageBytes(),
-                        ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString()
-                    }
+                        ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString() }
                 )
             );
         }
Review comment: This is old, unused code; decided to clean it up. Unrelated to the change in this PR.