diff --git a/docs/changelog/88187.yaml b/docs/changelog/88187.yaml new file mode 100644 index 0000000000000..17067c06c5d3c --- /dev/null +++ b/docs/changelog/88187.yaml @@ -0,0 +1,5 @@ +pr: 88187 +summary: Add deployed native models to `inference_stats` in trained model stats response +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java index b07216018cff3..a99b112490ab2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.ingest.IngestStats; import org.elasticsearch.xcontent.ParseField; @@ -269,8 +270,19 @@ public Builder setInferenceStatsByModelId(Map inferenceS return this; } + /** + * This sets the overall stats map and adds the models to the overall inference stats map + * @param assignmentStatsMap map of model_id to assignment stats + * @return the builder with inference stats map updated and assignment stats map set + */ public Builder setDeploymentStatsByModelId(Map assignmentStatsMap) { this.assignmentStatsMap = assignmentStatsMap; + if (inferenceStatsMap == null) { + inferenceStatsMap = Maps.newHashMapWithExpectedSize(assignmentStatsMap.size()); + } + assignmentStatsMap.forEach( + (modelId, assignmentStats) -> inferenceStatsMap.put(modelId, assignmentStats.getOverallInferenceStats()) + ); return this; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java index 6da7c9ebeb40b..08b280a6fb48b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java @@ -17,6 +17,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import java.io.IOException; import java.time.Instant; @@ -437,6 +438,23 @@ public AssignmentStats setReason(String reason) { return this; } + /** + * @return The overall inference stats for the model assignment + */ + public InferenceStats getOverallInferenceStats() { + return new InferenceStats( + 0L, + nodeStats.stream().filter(n -> n.getInferenceCount().isPresent()).mapToLong(n -> n.getInferenceCount().get()).sum(), + // This is for ALL failures, so sum the error counts, timeouts, and rejections + nodeStats.stream().mapToLong(n -> n.getErrorCount() + n.getTimeoutCount() + n.getRejectedExecutionCount()).sum(), + // TODO Update when we actually have cache miss/hit values + 0L, + modelId, + null, + Instant.now() + ); + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java index fec3016cce058..02697e7119d6c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java @@ -10,22 +10,23 @@ import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; -import java.net.InetAddress; import java.time.Instant; import java.util.ArrayList; import java.util.Comparator; import java.util.List; +import static org.hamcrest.Matchers.equalTo; + public class AssignmentStatsTests extends AbstractWireSerializingTestCase { public static AssignmentStats randomDeploymentStats() { List nodeStatsList = new ArrayList<>(); int numNodes = randomIntBetween(1, 4); for (int i = 0; i < numNodes; i++) { - var node = new DiscoveryNode("node_" + i, new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT); + var node = new DiscoveryNode("node_" + i, buildNewFakeTransportAddress(), Version.CURRENT); if (randomBoolean()) { nodeStatsList.add(randomNodeStats(node)); } else { @@ -82,6 +83,106 @@ public static AssignmentStats.NodeStats randomNodeStats(DiscoveryNode node) { ); } + public void testGetOverallInferenceStats() { + String modelId = randomAlphaOfLength(10); + + AssignmentStats existingStats = new AssignmentStats( + modelId, + randomBoolean() ? null : randomIntBetween(1, 8), + randomBoolean() ? null : randomIntBetween(1, 8), + randomBoolean() ? null : randomIntBetween(1, 10000), + Instant.now(), + List.of( + AssignmentStats.NodeStats.forStartedState( + new DiscoveryNode("node_started_1", buildNewFakeTransportAddress(), Version.CURRENT), + 10L, + randomDoubleBetween(0.0, 100.0, true), + randomIntBetween(1, 10), + 5, + 12, + 3, + Instant.now(), + Instant.now(), + randomIntBetween(1, 2), + randomIntBetween(1, 2), + randomNonNegativeLong(), + randomNonNegativeLong(), + null + ), + AssignmentStats.NodeStats.forStartedState( + new DiscoveryNode("node_started_2", buildNewFakeTransportAddress(), Version.CURRENT), + 12L, + randomDoubleBetween(0.0, 100.0, true), + randomIntBetween(1, 10), + 15, + 4, + 2, + Instant.now(), + Instant.now(), + randomIntBetween(1, 2), + randomIntBetween(1, 2), + randomNonNegativeLong(), + randomNonNegativeLong(), + null + ), + AssignmentStats.NodeStats.forNotStartedState( + new DiscoveryNode("node_not_started_3", buildNewFakeTransportAddress(), Version.CURRENT), + randomFrom(RoutingState.values()), + randomBoolean() ? null : "a good reason" + ) + ) + ); + InferenceStats stats = existingStats.getOverallInferenceStats(); + assertThat(stats.getModelId(), equalTo(modelId)); + assertThat(stats.getInferenceCount(), equalTo(22L)); + assertThat(stats.getFailureCount(), equalTo(41L)); + } + + public void testGetOverallInferenceStatsWithNoNodes() { + String modelId = randomAlphaOfLength(10); + + AssignmentStats existingStats = new AssignmentStats( + modelId, + randomBoolean() ? null : randomIntBetween(1, 8), + randomBoolean() ? null : randomIntBetween(1, 8), + randomBoolean() ? null : randomIntBetween(1, 10000), + Instant.now(), + List.of() + ); + InferenceStats stats = existingStats.getOverallInferenceStats(); + assertThat(stats.getModelId(), equalTo(modelId)); + assertThat(stats.getInferenceCount(), equalTo(0L)); + assertThat(stats.getFailureCount(), equalTo(0L)); + } + + public void testGetOverallInferenceStatsWithOnlyStoppedNodes() { + String modelId = randomAlphaOfLength(10); + + AssignmentStats existingStats = new AssignmentStats( + modelId, + randomBoolean() ? null : randomIntBetween(1, 8), + randomBoolean() ? null : randomIntBetween(1, 8), + randomBoolean() ? null : randomIntBetween(1, 10000), + Instant.now(), + List.of( + AssignmentStats.NodeStats.forNotStartedState( + new DiscoveryNode("node_not_started_1", buildNewFakeTransportAddress(), Version.CURRENT), + randomFrom(RoutingState.values()), + randomBoolean() ? null : "a good reason" + ), + AssignmentStats.NodeStats.forNotStartedState( + new DiscoveryNode("node_not_started_2", buildNewFakeTransportAddress(), Version.CURRENT), + randomFrom(RoutingState.values()), + randomBoolean() ? null : "a good reason" + ) + ) + ); + InferenceStats stats = existingStats.getOverallInferenceStats(); + assertThat(stats.getModelId(), equalTo(modelId)); + assertThat(stats.getInferenceCount(), equalTo(0L)); + assertThat(stats.getFailureCount(), equalTo(0L)); + } + @Override protected Writeable.Reader instanceReader() { return AssignmentStats::new; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index e1b8319167a35..1f315855fc04c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -248,6 +248,7 @@ public void testDeploymentStats() throws IOException { String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0)); assertThat(responseMap.toString(), statusState, is(not(nullValue()))); assertThat(AllocationStatus.State.fromString(statusState), greaterThanOrEqualTo(state)); + assertThat(XContentMapValues.extractValue("inference_stats", stats.get(0)), is(not(nullValue()))); Integer byteSize = (Integer) XContentMapValues.extractValue("model_size_stats.model_size_bytes", stats.get(0)); assertThat(responseMap.toString(), byteSize, is(not(nullValue()))); @@ -340,6 +341,7 @@ public void testLiveDeploymentStats() throws IOException { assertAtLeastOneOfTheseIsNotNull("last_access", nodes); assertAtLeastOneOfTheseIsNotNull("average_inference_time_ms", nodes); + assertThat((Integer) XContentMapValues.extractValue("inference_stats.inference_count", stats.get(0)), equalTo(2)); int inferenceCount = sumInferenceCountOnNodes(nodes); assertThat(inferenceCount, equalTo(2)); }