Skip to content

Commit

Permalink
more ml test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakelandis committed Sep 23, 2024
1 parent 5c07f81 commit 7c8a24c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.UpdateForV9;
import org.elasticsearch.test.rest.RestTestLegacyFeatures;
import org.elasticsearch.upgrades.FullClusterRestartUpgradeStatus;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
import org.junit.Before;

Expand Down Expand Up @@ -195,14 +197,30 @@ private Response startDeployment(String modelId) throws IOException {
}

private Response startDeployment(String modelId, String waitForState) throws IOException {
String inferenceThreadParamName = "threads_per_allocation";
String modelThreadParamName = "number_of_allocations";
String compatibleHeader = null;
if (isRunningAgainstOldCluster()) {
compatibleHeader = compatibleMediaType(XContentType.VND_JSON, RestApiVersion.V_8);
inferenceThreadParamName = "inference_threads";
modelThreadParamName = "model_threads";
}

Request request = new Request(
"POST",
"/_ml/trained_models/"
+ modelId
+ "/deployment/_start?timeout=40s&wait_for="
+ waitForState
+ "&inference_threads=1&model_threads=1"
+ "&"
+ inferenceThreadParamName
+ "=1&"
+ modelThreadParamName
+ "=1"
);
if (compatibleHeader != null) {
request.setOptions(request.getOptions().toBuilder().addHeader("Accept", compatibleHeader).build());
}
request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build());
var response = client().performRequest(request);
assertOK(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import org.elasticsearch.client.Response;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.UpdateForV9;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.test.rest.RestTestLegacyFeatures;
import org.elasticsearch.xcontent.XContentType;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -278,14 +280,30 @@ private Response startDeployment(String modelId) throws IOException {
}

private Response startDeployment(String modelId, String waitForState) throws IOException {
String inferenceThreadParamName = "threads_per_allocation";
String modelThreadParamName = "number_of_allocations";
String compatibleHeader = null;
if (CLUSTER_TYPE.equals(ClusterType.OLD) || CLUSTER_TYPE.equals(ClusterType.MIXED)) {
compatibleHeader = compatibleMediaType(XContentType.VND_JSON, RestApiVersion.V_8);
inferenceThreadParamName = "inference_threads";
modelThreadParamName = "model_threads";
}

Request request = new Request(
"POST",
"/_ml/trained_models/"
+ modelId
+ "/deployment/_start?timeout=40s&wait_for="
+ waitForState
+ "&inference_threads=1&model_threads=1"
+ "&"
+ inferenceThreadParamName
+ "=1&"
+ modelThreadParamName
+ "=1"
);
if (compatibleHeader != null) {
request.setOptions(request.getOptions().toBuilder().addHeader("Accept", compatibleHeader).build());
}
request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build());
var response = client().performRequest(request);
assertOK(response);
Expand Down

0 comments on commit 7c8a24c

Please sign in to comment.