Skip to content

Commit

Permalink
[ML] add deployed native models to inference_stats in trained model s…
Browse files Browse the repository at this point in the history
…tats response (elastic#88187)

This adds a valid `inference_stats` section for deployed native models.

`inference_stats` is effectively a subset of `deployment_stats`. It is a high-level view of the model's overall stats, whereas `deployment_stats` contains more detailed information about the types of errors seen, throughput, etc.
  • Loading branch information
benwtrent authored Jul 5, 2022
1 parent 66b5189 commit e5d1d10
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/88187.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 88187
summary: Add deployed native models to `inference_stats` in trained model stats response
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.xcontent.ParseField;
Expand Down Expand Up @@ -269,8 +270,19 @@ public Builder setInferenceStatsByModelId(Map<String, InferenceStats> inferenceS
return this;
}

/**
 * Stores the given assignment stats map on the builder and mirrors every entry into the
 * inference stats map.
 * <p>
 * For each model id in {@code assignmentStatsMap}, the result of
 * {@code AssignmentStats#getOverallInferenceStats()} is put into the inference stats map under the
 * same id, replacing any value already stored for that id. If no inference stats map has been set
 * yet, a new one sized for the incoming entries is created first.
 * <p>
 * NOTE(review): when an inference stats map was set earlier, it is written to in place;
 * presumably it must be mutable — confirm against the setter for that map.
 *
 * @param assignmentStatsMap map of model_id to assignment stats
 * @return the builder with inference stats map updated and assignment stats map set
 */
public Builder setDeploymentStatsByModelId(Map<String, AssignmentStats> assignmentStatsMap) {
    this.assignmentStatsMap = assignmentStatsMap;
    if (inferenceStatsMap == null) {
        inferenceStatsMap = Maps.newHashMapWithExpectedSize(assignmentStatsMap.size());
    }
    for (Map.Entry<String, AssignmentStats> deployment : assignmentStatsMap.entrySet()) {
        inferenceStatsMap.put(deployment.getKey(), deployment.getValue().getOverallInferenceStats());
    }
    return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -437,6 +438,23 @@ public AssignmentStats setReason(String reason) {
return this;
}

/**
 * Rolls the per-node counters of this assignment up into a single {@link InferenceStats}.
 * <p>
 * The inference count is the sum over the nodes that report one; nodes without an inference
 * count contribute nothing. The failure count is the sum, over all nodes, of the error,
 * timeout and rejected-execution counts. The timestamp is the moment this method is called.
 *
 * @return The overall inference stats for the model assignment
 */
public InferenceStats getOverallInferenceStats() {
    long totalInferences = 0L;
    for (NodeStats node : nodeStats) {
        if (node.getInferenceCount().isPresent()) {
            totalInferences += node.getInferenceCount().get();
        }
    }

    // This is for ALL failures, so sum the error counts, timeouts, and rejections
    long totalFailures = 0L;
    for (NodeStats node : nodeStats) {
        totalFailures += node.getErrorCount() + node.getTimeoutCount() + node.getRejectedExecutionCount();
    }

    // TODO Update when we actually have cache miss/hit values
    final long cacheMisses = 0L;

    return new InferenceStats(0L, totalInferences, totalFailures, cacheMisses, modelId, null, Instant.now());
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,23 @@
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;

import java.net.InetAddress;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public class AssignmentStatsTests extends AbstractWireSerializingTestCase<AssignmentStats> {

public static AssignmentStats randomDeploymentStats() {
List<AssignmentStats.NodeStats> nodeStatsList = new ArrayList<>();
int numNodes = randomIntBetween(1, 4);
for (int i = 0; i < numNodes; i++) {
var node = new DiscoveryNode("node_" + i, new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT);
var node = new DiscoveryNode("node_" + i, buildNewFakeTransportAddress(), Version.CURRENT);
if (randomBoolean()) {
nodeStatsList.add(randomNodeStats(node));
} else {
Expand Down Expand Up @@ -82,6 +83,106 @@ public static AssignmentStats.NodeStats randomNodeStats(DiscoveryNode node) {
);
}

/**
 * Checks that the overall stats add up the counters of the started nodes and that a node
 * which never started contributes nothing to either total.
 */
public void testGetOverallInferenceStats() {
    final String modelId = randomAlphaOfLength(10);

    // Constructor arguments the assertions do not depend on; drawn in constructor-argument order.
    final Integer nullableArgA = randomBoolean() ? null : randomIntBetween(1, 8);
    final Integer nullableArgB = randomBoolean() ? null : randomIntBetween(1, 8);
    final Integer nullableArgC = randomBoolean() ? null : randomIntBetween(1, 10000);
    final Instant timestampArg = Instant.now();

    // Started node: 10 inferences; the literal ints 5, 12 and 3 feed the failure total asserted below.
    final AssignmentStats.NodeStats firstStarted = AssignmentStats.NodeStats.forStartedState(
        new DiscoveryNode("node_started_1", buildNewFakeTransportAddress(), Version.CURRENT),
        10L,
        randomDoubleBetween(0.0, 100.0, true),
        randomIntBetween(1, 10),
        5,
        12,
        3,
        Instant.now(),
        Instant.now(),
        randomIntBetween(1, 2),
        randomIntBetween(1, 2),
        randomNonNegativeLong(),
        randomNonNegativeLong(),
        null
    );
    // Started node: 12 inferences; the literal ints 15, 4 and 2 feed the failure total asserted below.
    final AssignmentStats.NodeStats secondStarted = AssignmentStats.NodeStats.forStartedState(
        new DiscoveryNode("node_started_2", buildNewFakeTransportAddress(), Version.CURRENT),
        12L,
        randomDoubleBetween(0.0, 100.0, true),
        randomIntBetween(1, 10),
        15,
        4,
        2,
        Instant.now(),
        Instant.now(),
        randomIntBetween(1, 2),
        randomIntBetween(1, 2),
        randomNonNegativeLong(),
        randomNonNegativeLong(),
        null
    );
    // Node without any counters: must not change the totals.
    final AssignmentStats.NodeStats neverStarted = AssignmentStats.NodeStats.forNotStartedState(
        new DiscoveryNode("node_not_started_3", buildNewFakeTransportAddress(), Version.CURRENT),
        randomFrom(RoutingState.values()),
        randomBoolean() ? null : "a good reason"
    );

    final AssignmentStats assignment = new AssignmentStats(
        modelId,
        nullableArgA,
        nullableArgB,
        nullableArgC,
        timestampArg,
        List.of(firstStarted, secondStarted, neverStarted)
    );

    final InferenceStats overall = assignment.getOverallInferenceStats();
    assertThat(overall.getModelId(), equalTo(modelId));
    assertThat(overall.getInferenceCount(), equalTo(10L + 12L));
    assertThat(overall.getFailureCount(), equalTo(5L + 12L + 3L + 15L + 4L + 2L));
}

/**
 * Checks that an assignment with an empty node list yields zero inference and failure counts
 * while still carrying the model id.
 */
public void testGetOverallInferenceStatsWithNoNodes() {
    final String modelId = randomAlphaOfLength(10);

    // Constructor arguments the assertions do not depend on; drawn in constructor-argument order.
    final Integer nullableArgA = randomBoolean() ? null : randomIntBetween(1, 8);
    final Integer nullableArgB = randomBoolean() ? null : randomIntBetween(1, 8);
    final Integer nullableArgC = randomBoolean() ? null : randomIntBetween(1, 10000);
    final Instant timestampArg = Instant.now();

    final AssignmentStats assignment = new AssignmentStats(
        modelId,
        nullableArgA,
        nullableArgB,
        nullableArgC,
        timestampArg,
        List.of()
    );

    final InferenceStats overall = assignment.getOverallInferenceStats();
    assertThat(overall.getModelId(), equalTo(modelId));
    assertThat(overall.getInferenceCount(), equalTo(0L));
    assertThat(overall.getFailureCount(), equalTo(0L));
}

/**
 * Checks that an assignment whose nodes were all built via {@code forNotStartedState} yields
 * zero inference and failure counts while still carrying the model id.
 */
public void testGetOverallInferenceStatsWithOnlyStoppedNodes() {
    final String modelId = randomAlphaOfLength(10);

    // Constructor arguments the assertions do not depend on; drawn in constructor-argument order.
    final Integer nullableArgA = randomBoolean() ? null : randomIntBetween(1, 8);
    final Integer nullableArgB = randomBoolean() ? null : randomIntBetween(1, 8);
    final Integer nullableArgC = randomBoolean() ? null : randomIntBetween(1, 10000);
    final Instant timestampArg = Instant.now();

    final AssignmentStats.NodeStats firstIdle = AssignmentStats.NodeStats.forNotStartedState(
        new DiscoveryNode("node_not_started_1", buildNewFakeTransportAddress(), Version.CURRENT),
        randomFrom(RoutingState.values()),
        randomBoolean() ? null : "a good reason"
    );
    final AssignmentStats.NodeStats secondIdle = AssignmentStats.NodeStats.forNotStartedState(
        new DiscoveryNode("node_not_started_2", buildNewFakeTransportAddress(), Version.CURRENT),
        randomFrom(RoutingState.values()),
        randomBoolean() ? null : "a good reason"
    );

    final AssignmentStats assignment = new AssignmentStats(
        modelId,
        nullableArgA,
        nullableArgB,
        nullableArgC,
        timestampArg,
        List.of(firstIdle, secondIdle)
    );

    final InferenceStats overall = assignment.getOverallInferenceStats();
    assertThat(overall.getModelId(), equalTo(modelId));
    assertThat(overall.getInferenceCount(), equalTo(0L));
    assertThat(overall.getFailureCount(), equalTo(0L));
}

@Override
protected Writeable.Reader<AssignmentStats> instanceReader() {
return AssignmentStats::new;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ public void testDeploymentStats() throws IOException {
String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0));
assertThat(responseMap.toString(), statusState, is(not(nullValue())));
assertThat(AllocationStatus.State.fromString(statusState), greaterThanOrEqualTo(state));
assertThat(XContentMapValues.extractValue("inference_stats", stats.get(0)), is(not(nullValue())));

Integer byteSize = (Integer) XContentMapValues.extractValue("model_size_stats.model_size_bytes", stats.get(0));
assertThat(responseMap.toString(), byteSize, is(not(nullValue())));
Expand Down Expand Up @@ -340,6 +341,7 @@ public void testLiveDeploymentStats() throws IOException {
assertAtLeastOneOfTheseIsNotNull("last_access", nodes);
assertAtLeastOneOfTheseIsNotNull("average_inference_time_ms", nodes);

assertThat((Integer) XContentMapValues.extractValue("inference_stats.inference_count", stats.get(0)), equalTo(2));
int inferenceCount = sumInferenceCountOnNodes(nodes);
assertThat(inferenceCount, equalTo(2));
}
Expand Down

0 comments on commit e5d1d10

Please sign in to comment.