Skip to content

Commit

Permalink
Check for ml privilege when using the Inference Aggregation (#59530) (#…
Browse files Browse the repository at this point in the history
…59562)

The inference pipeline aggregation requires the user has permission to access
the ml get trained models endpoint (_ml/inference/)
  • Loading branch information
davidkyle authored Jul 14, 2020
1 parent 408a07f commit 0d2ea1b
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 6 deletions.
1 change: 1 addition & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ integTest.runner {
'ml/job_groups/Test put job with id that matches an existing group',
'ml/job_groups/Test put job with invalid group',
'ml/ml_info/Test ml info',
'ml/pipeline_inference/Test setting results field is invalid',
'ml/post_data/Test Flush data with invalid parameters',
'ml/post_data/Test flushing and posting a closed job',
'ml/post_data/Test open and close with non-existent job id',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import org.elasticsearch.test.rest.yaml.section.ExecutableSection;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.either;
Expand All @@ -26,6 +29,7 @@ public MlWithSecurityInsufficientRoleIT(@Name("yaml") ClientYamlTestCandidate te
}

@Override
@SuppressWarnings("unchecked")
public void test() throws IOException {
try {
// Cannot use expectThrows here because blacklisted tests will throw an
Expand All @@ -38,7 +42,19 @@ public void test() throws IOException {
String apiName = ((DoSection) section).getApiCallSection().getApi();

if (apiName.startsWith("ml.")) {
fail("call to ml endpoint should have failed because of missing role");
fail("call to ml endpoint [" + apiName + "] should have failed because of missing role");
} else if (apiName.startsWith("search")) {
DoSection doSection = (DoSection) section;
List<Map<String, Object>> bodies = doSection.getApiCallSection().getBodies();
boolean containsInferenceAgg = false;
for (Map<String, Object> body : bodies) {
Map<String, Object> aggs = (Map<String, Object>)body.get("aggs");
containsInferenceAgg = containsInferenceAgg || containsKey("inference", aggs);
}

if (containsInferenceAgg) {
fail("call to [search] with the ml inference agg should have failed because of missing role");
}
}
}
}
Expand All @@ -49,9 +65,13 @@ public void test() throws IOException {
assertThat(ae.getMessage(), containsString("but was Integer [0]"));
} else {
assertThat(ae.getMessage(),
either(containsString("action [cluster:monitor/xpack/ml")).or(containsString("action [cluster:admin/xpack/ml")));
either(containsString("action [cluster:monitor/xpack/ml"))
.or(containsString("action [cluster:admin/xpack/ml"))
.or(containsString("security_exception")));
assertThat(ae.getMessage(), containsString("returned [403 Forbidden]"));
assertThat(ae.getMessage(), containsString("is unauthorized for user [no_ml]"));
assertThat(ae.getMessage(),
either(containsString("is unauthorized for user [no_ml]"))
.or(containsString("user [no_ml] does not have the privilege to get trained models")));
}
}
}
Expand All @@ -60,5 +80,24 @@ public void test() throws IOException {
protected String[] getCredentials() {
return new String[]{"no_ml", "x-pack-test-password"};
}

@SuppressWarnings("unchecked")
static boolean containsKey(String key, Map<String, Object> mapOfMaps) {
if (mapOfMaps.containsKey(key)) {
return true;
}

Set<Map.Entry<String, Object>> entries = mapOfMaps.entrySet();
for (Map.Entry<String, Object> entry : entries) {
if (entry.getValue() instanceof Map<?,?>) {
boolean isInNestedMap = containsKey(key, (Map<String, Object>)entry.getValue());
if (isInNestedMap) {
return true;
}
}
}

return false;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ public void test() throws IOException {
String apiName = ((DoSection) section).getApiCallSection().getApi();

if (apiName.startsWith("ml.") && isAllowed(apiName) == false) {
fail("should have failed because of missing role");
fail("call to ml endpoint [" + apiName + "] should have failed because of missing role");
}
}
}
} catch (AssertionError ae) {
assertThat(ae.getMessage(),
either(containsString("action [cluster:monitor/xpack/ml")).or(containsString("action [cluster:admin/xpack/ml")));
either(containsString("action [cluster:monitor/xpack/ml"))
.or(containsString("action [cluster:admin/xpack/ml")));
assertThat(ae.getMessage(), containsString("returned [403 Forbidden]"));
assertThat(ae.getMessage(), containsString("is unauthorized for user [ml_user]"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
Expand All @@ -21,20 +23,29 @@
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesResponse;
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.support.Exceptions;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;

public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {

Expand Down Expand Up @@ -186,8 +197,9 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context)
if (model != null) {
return this;
}

SetOnce<LocalModel> loadedModel = new SetOnce<>();
context.registerAsyncAction((client, listener) -> {
BiConsumer<Client, ActionListener<?>> modelLoadAction = (client, listener) ->
modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
loadedModel.set(model);

Expand All @@ -199,6 +211,36 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context)
delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
}
}));


context.registerAsyncAction((client, listener) -> {
if (licenseState.isSecurityEnabled()) {
// check the user has ml privileges
SecurityContext securityContext = new SecurityContext(Settings.EMPTY, client.threadPool().getThreadContext());
useSecondaryAuthIfAvailable(securityContext, () -> {
final String username = securityContext.getUser().principal();
final HasPrivilegesRequest privRequest = new HasPrivilegesRequest();
privRequest.username(username);
privRequest.clusterPrivileges(GetTrainedModelsAction.NAME);
privRequest.indexPrivileges(new RoleDescriptor.IndicesPrivileges[]{});
privRequest.applicationPrivileges(new RoleDescriptor.ApplicationResourcePrivileges[]{});

ActionListener<HasPrivilegesResponse> privResponseListener = ActionListener.wrap(
r -> {
if (r.isCompleteMatch()) {
modelLoadAction.accept(client, listener);
} else {
listener.onFailure(Exceptions.authorizationError("user [" + username
+ "] does not have the privilege to get trained models so cannot use ml inference"));
}
},
listener::onFailure);

client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener);
});
} else {
modelLoadAction.accept(client, listener);
}
});
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ setup:
}
}
- match: { aggregations.good.buckets.0.regression_agg.value: 2.0 }

---
"Test pipeline agg referencing a single bucket":

Expand Down

0 comments on commit 0d2ea1b

Please sign in to comment.