Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Distribute model allocations respecting node allocated processors #87366

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
e6f3b51
[ML] Improve scalability of NLP models
dimitris-athanasiou May 5, 2022
e306997
Fix tests
dimitris-athanasiou Jun 6, 2022
8ed60a8
Minor fixes
dimitris-athanasiou Jun 6, 2022
6f25e73
Fix rollin upgrade test
dimitris-athanasiou Jun 7, 2022
3e27562
Remove assertion that depends on which node is coordinating
dimitris-athanasiou Jun 7, 2022
95eabd2
Update docs/changelog/87366.yaml
dimitris-athanasiou Jun 7, 2022
c6920ce
Merge branch 'master' into improve-scalability-of-nlp-models
elasticmachine Jun 9, 2022
7895e90
Simply get allocated processors using EsExecutors.allocatedProcessors
dimitris-athanasiou Jun 9, 2022
5f5187b
Address review comments
dimitris-athanasiou Jun 9, 2022
70d4254
Skip rebalancing when cluster not fully on version supporting it
dimitris-athanasiou Jun 9, 2022
7a04304
Allow partial updates to routing info
dimitris-athanasiou Jun 9, 2022
f74cf35
On mixed cluster select started node on older version
dimitris-athanasiou Jun 9, 2022
7638a69
fix formatting
dimitris-athanasiou Jun 9, 2022
927cc5d
Some review comments
dimitris-athanasiou Jun 14, 2022
38cae6f
Revert "On mixed cluster select started node on older version"
dimitris-athanasiou Jun 14, 2022
99a3c8e
In mixed cluster remove routing to removed or shutting down nodes
dimitris-athanasiou Jun 14, 2022
65e4ed4
Some more comments
dimitris-athanasiou Jun 14, 2022
2595710
Fix format
dimitris-athanasiou Jun 14, 2022
71b6796
Merge branch 'master' into improve-scalability-of-nlp-models
dimitris-athanasiou Jun 15, 2022
5e3025f
Remove use of ParameterizedMessage
dimitris-athanasiou Jun 15, 2022
d90f7dd
Improve logging
dimitris-athanasiou Jun 15, 2022
1453fec
Merge branch 'master' into improve-scalability-of-nlp-models
dimitris-athanasiou Jun 20, 2022
75f4dcd
Merge branch 'master' into improve-scalability-of-nlp-models
elasticmachine Jun 20, 2022
3852f63
Fix merge problems
dimitris-athanasiou Jun 20, 2022
292b369
Minor rename
dimitris-athanasiou Jun 20, 2022
82a8704
Simplify log statement
dimitris-athanasiou Jun 20, 2022
3ba736c
Fix memory counting when merging assignment plan
dimitris-athanasiou Jun 20, 2022
6533527
Wrong Strings import
dimitris-athanasiou Jun 20, 2022
816db42
More logging and fixing an NPE
dimitris-athanasiou Jun 21, 2022
c17fa8b
Correctly report number_of_allocations for entire deployment
dimitris-athanasiou Jun 21, 2022
029fe62
Filter out removed nodes from current assignments
dimitris-athanasiou Jun 21, 2022
8f45921
Upgrade some logging from trace to debug
dimitris-athanasiou Jun 21, 2022
13ccfd8
Throw Exception from rebalance
dimitris-athanasiou Jun 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/87366.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 87366
summary: Improve scalability of NLP models
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,36 @@
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class UpdateTrainedModelAssignmentStateAction extends ActionType<AcknowledgedResponse> {
public static final UpdateTrainedModelAssignmentStateAction INSTANCE = new UpdateTrainedModelAssignmentStateAction();
public class UpdateTrainedModelAssignmentRoutingInfoAction extends ActionType<AcknowledgedResponse> {
public static final UpdateTrainedModelAssignmentRoutingInfoAction INSTANCE = new UpdateTrainedModelAssignmentRoutingInfoAction();
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/update";

private UpdateTrainedModelAssignmentStateAction() {
private UpdateTrainedModelAssignmentRoutingInfoAction() {
super(NAME, AcknowledgedResponse::readFrom);
}

public static class Request extends MasterNodeRequest<Request> {
private final String nodeId;
private final String modelId;
private final RoutingStateAndReason routingState;
private final RoutingInfoUpdate update;

public Request(String nodeId, String modelId, RoutingStateAndReason routingState) {
public Request(String nodeId, String modelId, RoutingInfoUpdate update) {
this.nodeId = ExceptionsHelper.requireNonNull(nodeId, "node_id");
this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
this.routingState = ExceptionsHelper.requireNonNull(routingState, "routing_state");
this.update = ExceptionsHelper.requireNonNull(update, "update");
}

public Request(StreamInput in) throws IOException {
super(in);
this.nodeId = in.readString();
this.modelId = in.readString();
this.routingState = new RoutingStateAndReason(in);
this.update = new RoutingInfoUpdate(in);
}

public String getNodeId() {
Expand All @@ -53,8 +53,8 @@ public String getModelId() {
return modelId;
}

public RoutingStateAndReason getRoutingState() {
return routingState;
public RoutingInfoUpdate getUpdate() {
return update;
}

@Override
Expand All @@ -67,7 +67,7 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(nodeId);
out.writeString(modelId);
routingState.writeTo(out);
update.writeTo(out);
}

@Override
Expand All @@ -77,17 +77,17 @@ public boolean equals(Object o) {
Request request = (Request) o;
return Objects.equals(nodeId, request.nodeId)
&& Objects.equals(modelId, request.modelId)
&& Objects.equals(routingState, request.routingState);
&& Objects.equals(update, request.update);
}

@Override
public int hashCode() {
return Objects.hash(nodeId, modelId, routingState);
return Objects.hash(nodeId, modelId, update);
}

@Override
public String toString() {
return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", routingState=" + routingState + '}';
return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", update=" + update + '}';
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.assignment;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class RoutingInfo implements ToXContentObject, Writeable {

private static final ParseField CURRENT_ALLOCATIONS = new ParseField("current_allocations");
private static final ParseField TARGET_ALLOCATIONS = new ParseField("target_allocations");
private static final ParseField ROUTING_STATE = new ParseField("routing_state");
private static final ParseField REASON = new ParseField("reason");

private static final ConstructingObjectParser<RoutingInfo, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_routing_state",
a -> new RoutingInfo((Integer) a[0], (Integer) a[1], RoutingState.fromString((String) a[2]), (String) a[3])
);
static {
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), CURRENT_ALLOCATIONS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), TARGET_ALLOCATIONS);
PARSER.declareString(ConstructingObjectParser.constructorArg(), ROUTING_STATE);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON);
}

public static RoutingInfo fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final int currentAllocations;
private final int targetAllocations;
private final RoutingState state;
private final String reason;

// There may be objects in cluster state prior to 8.4 that do not contain values for currentAllocations and targetAllocations.
private RoutingInfo(
@Nullable Integer currentAllocations,
@Nullable Integer targetAllocations,
RoutingState state,
@Nullable String reason
) {
this(currentAllocations == null ? 0 : currentAllocations, targetAllocations == null ? 0 : targetAllocations, state, reason);
}

public RoutingInfo(int currentAllocations, int targetAllocations, RoutingState state, String reason) {
this.currentAllocations = currentAllocations;
this.targetAllocations = targetAllocations;
this.state = ExceptionsHelper.requireNonNull(state, ROUTING_STATE);
this.reason = reason;
}

public RoutingInfo(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
this.currentAllocations = in.readVInt();
this.targetAllocations = in.readVInt();
} else {
this.currentAllocations = 0;
this.targetAllocations = 0;
}
this.state = in.readEnum(RoutingState.class);
this.reason = in.readOptionalString();
}

public int getCurrentAllocations() {
return currentAllocations;
}

public int getTargetAllocations() {
return targetAllocations;
}

public RoutingState getState() {
return state;
}

@Nullable
public String getReason() {
return reason;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeVInt(currentAllocations);
out.writeVInt(targetAllocations);
}
out.writeEnum(state);
out.writeOptionalString(reason);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CURRENT_ALLOCATIONS.getPreferredName(), currentAllocations);
builder.field(TARGET_ALLOCATIONS.getPreferredName(), targetAllocations);
builder.field(ROUTING_STATE.getPreferredName(), state);
if (reason != null) {
builder.field(REASON.getPreferredName(), reason);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RoutingInfo that = (RoutingInfo) o;
return currentAllocations == that.currentAllocations
&& targetAllocations == that.targetAllocations
&& state == that.state
&& Objects.equals(reason, that.reason);
}

@Override
public int hashCode() {
return Objects.hash(currentAllocations, targetAllocations, state, reason);
}

@Override
public String toString() {
return "RoutingInfo{"
+ "current_allocations="
+ currentAllocations
+ ", target_allocations="
+ targetAllocations
+ ", reason='"
+ reason
+ '\''
+ ", state="
+ state
+ '}';
}

public boolean isRoutable() {
return state == RoutingState.STARTED && currentAllocations > 0;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.assignment;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class RoutingInfoUpdate implements Writeable {

private final Optional<Integer> numberOfAllocations;
private final Optional<RoutingStateAndReason> stateAndReason;

public static RoutingInfoUpdate updateNumberOfAllocations(int numberOfAllocations) {
return new RoutingInfoUpdate(Optional.of(numberOfAllocations), Optional.empty());
}

public static RoutingInfoUpdate updateStateAndReason(RoutingStateAndReason routingStateAndReason) {
return new RoutingInfoUpdate(Optional.empty(), Optional.of(routingStateAndReason));
}

private RoutingInfoUpdate(Optional<Integer> numberOfAllocations, Optional<RoutingStateAndReason> stateAndReason) {
this.numberOfAllocations = Objects.requireNonNull(numberOfAllocations);
this.stateAndReason = Objects.requireNonNull(stateAndReason);
}

public RoutingInfoUpdate(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
numberOfAllocations = Optional.ofNullable(in.readOptionalVInt());
stateAndReason = Optional.ofNullable(in.readOptionalWriteable(RoutingStateAndReason::new));
} else {
numberOfAllocations = Optional.empty();
stateAndReason = Optional.of(new RoutingStateAndReason(in));
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalVInt(numberOfAllocations.orElse(null));
out.writeOptionalWriteable(stateAndReason.orElse(null));
} else {
assert stateAndReason.isPresent() : "updating routing info while nodes prior to 8.4.0 should only contain state and reason";
stateAndReason.get().writeTo(out);
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RoutingInfoUpdate that = (RoutingInfoUpdate) o;
return Objects.equals(numberOfAllocations, that.numberOfAllocations) && Objects.equals(stateAndReason, that.stateAndReason);
}

@Override
public int hashCode() {
return Objects.hash(numberOfAllocations, stateAndReason);
}

@Override
public String toString() {
return "RoutingInfoUpdate{" + "numberOfAllocations=" + numberOfAllocations + ", stateAndReason=" + stateAndReason + '}';
}

public Optional<Integer> getNumberOfAllocations() {
return numberOfAllocations;
}

public Optional<RoutingStateAndReason> getStateAndReason() {
return stateAndReason;
}

public RoutingInfo apply(RoutingInfo routingInfo) {
int currentAllocations = numberOfAllocations.orElse(routingInfo.getCurrentAllocations());
RoutingState state = routingInfo.getState();
String reason = routingInfo.getReason();
if (stateAndReason.isPresent()) {
state = stateAndReason.get().getState();
reason = stateAndReason.get().getReason();
}
return new RoutingInfo(currentAllocations, routingInfo.getTargetAllocations(), state, reason);
}
}
Loading