Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] make trained model rest APIs cancellable #88009

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.action.util.PageParams;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public abstract class AbstractGetResourcesRequest extends ActionRequest {
Expand Down Expand Up @@ -93,5 +97,12 @@ public boolean equals(Object obj) {
&& allowNoResources == other.allowNoResources;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, getCancelableTaskDescription(), parentTaskId, headers);
}

public abstract String getCancelableTaskDescription();

public abstract String getResourceIdField();
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
Expand Down Expand Up @@ -73,7 +74,7 @@ protected AbstractTransportGetResourcesAction(
this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
}

protected void searchResources(AbstractGetResourcesRequest request, ActionListener<QueryPage<Resource>> listener) {
protected void searchResources(AbstractGetResourcesRequest request, TaskId parentTaskId, ActionListener<QueryPage<Resource>> listener) {
String[] tokens = Strings.tokenizeToStringArray(request.getResourceId(), ",");
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().sort(
SortBuilders.fieldSort(request.getResourceIdField())
Expand All @@ -96,6 +97,7 @@ protected void searchResources(AbstractGetResourcesRequest request, ActionListen
indicesOptions
)
).source(customSearchOptions(sourceBuilder));
searchRequest.setParentTask(parentTaskId);

executeAsyncWithOrigin(
client.threadPool().getThreadContext(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import java.io.IOException;

import static org.elasticsearch.core.Strings.format;

public class GetDataFrameAnalyticsAction extends ActionType<GetDataFrameAnalyticsAction.Response> {

public static final GetDataFrameAnalyticsAction INSTANCE = new GetDataFrameAnalyticsAction();
Expand Down Expand Up @@ -46,6 +48,11 @@ public Request(StreamInput in) throws IOException {
public String getResourceIdField() {
return DataFrameAnalyticsConfig.ID.getPreferredName();
}

@Override
public String getCancelableTaskDescription() {
return format("get_data_frame_analytics[%s]", getResourceId());
}
}

public static class Response extends AbstractGetResourcesResponse<DataFrameAnalyticsConfig> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.io.IOException;

import static org.elasticsearch.core.Strings.format;

public class GetFiltersAction extends ActionType<GetFiltersAction.Response> {

public static final GetFiltersAction INSTANCE = new GetFiltersAction();
Expand All @@ -41,6 +43,11 @@ public Request(StreamInput in) throws IOException {
super(in);
}

@Override
public String getCancelableTaskDescription() {
return format("get_filters[%s]", getResourceId());
}

@Override
public String getResourceIdField() {
return MlFilter.ID.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.util.Objects;
import java.util.Set;

import static org.elasticsearch.core.Strings.format;

public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {

public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction();
Expand Down Expand Up @@ -118,7 +120,6 @@ public int hashCode() {
public static class Request extends AbstractGetResourcesRequest {

public static final ParseField INCLUDE = new ParseField("include");
public static final String DEFINITION = "definition";
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public static final ParseField TAGS = new ParseField("tags");

Expand Down Expand Up @@ -178,6 +179,11 @@ public boolean equals(Object obj) {
Request other = (Request) obj;
return super.equals(obj) && this.includes.equals(other.includes) && Objects.equals(tags, other.tags);
}

@Override
public String getCancelableTaskDescription() {
return format("get_trained_models[%s]", getResourceId());
}
}

public static class Response extends AbstractGetResourcesResponse<TrainedModelConfig> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.Set;

import static org.elasticsearch.core.RestApiVersion.onOrAfter;
import static org.elasticsearch.core.Strings.format;

public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStatsAction.Response> {

Expand Down Expand Up @@ -68,6 +69,11 @@ public Request(StreamInput in) throws IOException {
super(in);
}

@Override
public String getCancelableTaskDescription() {
return format("get_trained_model_stats[%s]", getResourceId());
}

@Override
public String getResourceIdField() {
return TrainedModelConfig.MODEL_ID.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -31,6 +34,8 @@
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;

public class InferModelAction extends ActionType<InferModelAction.Response> {
public static final String NAME = "cluster:internal/xpack/ml/inference/infer";
public static final String EXTERNAL_NAME = "cluster:monitor/xpack/ml/inference/infer";
Expand Down Expand Up @@ -176,6 +181,11 @@ public boolean equals(Object o) {
&& Objects.equals(objectsToInfer, that.objectsToInfer);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, format("infer_trained_model[%s]", modelId), parentTaskId, headers);
}

@Override
public int hashCode() {
return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed, timeout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -36,6 +38,7 @@
import java.util.Optional;

import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.core.Strings.format;

public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedModelDeploymentAction.Response> {

Expand Down Expand Up @@ -192,6 +195,11 @@ public int hashCode() {
return Objects.hash(deploymentId, update, docs, inferenceTimeout);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, format("infer_trained_model_deployment[%s]", deploymentId), parentTaskId, headers);
}

public static class Builder {

private String deploymentId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Objects;

import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.core.Strings.format;

public class GetTransformAction extends ActionType<GetTransformAction.Response> {

Expand Down Expand Up @@ -76,6 +77,11 @@ public ActionRequestValidationException validate() {
return exception;
}

@Override
public String getCancelableTaskDescription() {
return format("get_transforms[%s]", getResourceId());
}

@Override
public String getResourceIdField() {
return TransformField.ID.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,15 @@ public void testStoreModelViaChunkedPersister() throws IOException {
PageParams.defaultParams(),
Collections.emptySet(),
ModelAliasMetadata.EMPTY,
null,
getIdsFuture
);
Tuple<Long, Map<String, Set<String>>> ids = getIdsFuture.actionGet();
assertThat(ids.v1(), equalTo(1L));
String inferenceModelId = ids.v2().keySet().iterator().next();

PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), getTrainedModelFuture);
trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), null, getTrainedModelFuture);

TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
Expand All @@ -128,7 +129,7 @@ public void testStoreModelViaChunkedPersister() throws IOException {
assertThat(storedConfig.getMetadata(), hasKey("hyperparameters"));

PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), null, getTrainedMetadataFuture);

TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId);
assertThat(storedMetadata.getModelId(), startsWith(modelId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public void testGetTrainedModelConfig() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand All @@ -132,7 +132,7 @@ public void testGetTrainedModelConfig() throws Exception {

getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.all(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.all(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down Expand Up @@ -204,7 +204,7 @@ public void testGetTrainedModelConfigWithMultiDocDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down Expand Up @@ -248,7 +248,7 @@ public void testGetTrainedModelConfigWithoutDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand All @@ -263,7 +263,7 @@ public void testGetMissingTrainingModelConfig() throws Exception {
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand All @@ -288,7 +288,7 @@ public void testGetMissingTrainingModelConfigDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down Expand Up @@ -335,7 +335,7 @@ public void testGetTruncatedModelDeprecatedDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down Expand Up @@ -388,7 +388,7 @@ public void testGetTruncatedModelDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
Expand All @@ -31,12 +33,14 @@ public class TransportGetDataFrameAnalyticsAction extends AbstractTransportGetRe
DataFrameAnalyticsConfig,
GetDataFrameAnalyticsAction.Request,
GetDataFrameAnalyticsAction.Response> {
private final ClusterService clusterService;

@Inject
public TransportGetDataFrameAnalyticsAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
ClusterService clusterService,
NamedXContentRegistry xContentRegistry
) {
super(
Expand All @@ -47,6 +51,7 @@ public TransportGetDataFrameAnalyticsAction(
client,
xContentRegistry
);
this.clusterService = clusterService;
}

@Override
Expand Down Expand Up @@ -77,6 +82,7 @@ protected void doExecute(
) {
searchResources(
request,
new TaskId(clusterService.localNode().getId(), task.getId()),
ActionListener.wrap(queryPage -> listener.onResponse(new GetDataFrameAnalyticsAction.Response(queryPage)), listener::onFailure)
);
}
Expand Down
Loading