Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for ml privilege when using the Inference Aggregation #59530

Merged
merged 5 commits into from
Jul 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ integTest.runner {
'ml/job_groups/Test put job with id that matches an existing group',
'ml/job_groups/Test put job with invalid group',
'ml/ml_info/Test ml info',
'ml/pipeline_inference/Test setting results field is invalid',
'ml/post_data/Test Flush data with invalid parameters',
'ml/post_data/Test flushing and posting a closed job',
'ml/post_data/Test open and close with non-existent job id',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import org.elasticsearch.test.rest.yaml.section.ExecutableSection;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.either;
Expand All @@ -26,6 +29,7 @@ public MlWithSecurityInsufficientRoleIT(@Name("yaml") ClientYamlTestCandidate te
}

@Override
@SuppressWarnings("unchecked")
public void test() throws IOException {
try {
// Cannot use expectThrows here because blacklisted tests will throw an
Expand All @@ -38,7 +42,19 @@ public void test() throws IOException {
String apiName = ((DoSection) section).getApiCallSection().getApi();

if (apiName.startsWith("ml.")) {
fail("call to ml endpoint should have failed because of missing role");
fail("call to ml endpoint [" + apiName + "] should have failed because of missing role");
} else if (apiName.startsWith("search")) {
DoSection doSection = (DoSection) section;
List<Map<String, Object>> bodies = doSection.getApiCallSection().getBodies();
boolean containsInferenceAgg = false;
for (Map<String, Object> body : bodies) {
Map<String, Object> aggs = (Map<String, Object>)body.get("aggs");
containsInferenceAgg = containsInferenceAgg || containsKey("inference", aggs);
}

if (containsInferenceAgg) {
fail("call to [search] with the ml inference agg should have failed because of missing role");
}
}
}
}
Expand All @@ -49,9 +65,13 @@ public void test() throws IOException {
assertThat(ae.getMessage(), containsString("but was Integer [0]"));
} else {
assertThat(ae.getMessage(),
either(containsString("action [cluster:monitor/xpack/ml")).or(containsString("action [cluster:admin/xpack/ml")));
either(containsString("action [cluster:monitor/xpack/ml"))
.or(containsString("action [cluster:admin/xpack/ml"))
.or(containsString("security_exception")));
assertThat(ae.getMessage(), containsString("returned [403 Forbidden]"));
assertThat(ae.getMessage(), containsString("is unauthorized for user [no_ml]"));
assertThat(ae.getMessage(),
either(containsString("is unauthorized for user [no_ml]"))
.or(containsString("user [no_ml] does not have the privilege to get trained models")));
}
}
}
Expand All @@ -60,5 +80,24 @@ public void test() throws IOException {
protected String[] getCredentials() {
return new String[]{"no_ml", "x-pack-test-password"};
}

@SuppressWarnings("unchecked")
static boolean containsKey(String key, Map<String, Object> mapOfMaps) {
if (mapOfMaps.containsKey(key)) {
return true;
}

Set<Map.Entry<String, Object>> entries = mapOfMaps.entrySet();
for (Map.Entry<String, Object> entry : entries) {
if (entry.getValue() instanceof Map<?,?>) {
boolean isInNestedMap = containsKey(key, (Map<String, Object>)entry.getValue());
if (isInNestedMap) {
return true;
}
}
}

return false;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ public void test() throws IOException {
String apiName = ((DoSection) section).getApiCallSection().getApi();

if (apiName.startsWith("ml.") && isAllowed(apiName) == false) {
fail("should have failed because of missing role");
fail("call to ml endpoint [" + apiName + "] should have failed because of missing role");
}
}
}
} catch (AssertionError ae) {
assertThat(ae.getMessage(),
either(containsString("action [cluster:monitor/xpack/ml")).or(containsString("action [cluster:admin/xpack/ml")));
either(containsString("action [cluster:monitor/xpack/ml"))
.or(containsString("action [cluster:admin/xpack/ml")));
assertThat(ae.getMessage(), containsString("returned [403 Forbidden]"));
assertThat(ae.getMessage(), containsString("is unauthorized for user [ml_user]"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
Expand All @@ -21,20 +23,29 @@
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesResponse;
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.support.Exceptions;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;

public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {

Expand Down Expand Up @@ -186,8 +197,9 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context)
if (model != null) {
return this;
}

SetOnce<LocalModel> loadedModel = new SetOnce<>();
context.registerAsyncAction((client, listener) -> {
BiConsumer<Client, ActionListener<?>> modelLoadAction = (client, listener) ->
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, nice! That's what to use when in need to return more than a single thing. I'll use that onwards!

modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
loadedModel.set(model);

Expand All @@ -199,6 +211,36 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context)
delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
}
}));


context.registerAsyncAction((client, listener) -> {
if (licenseState.isSecurityEnabled()) {
// check the user has ml privileges
SecurityContext securityContext = new SecurityContext(Settings.EMPTY, client.threadPool().getThreadContext());
useSecondaryAuthIfAvailable(securityContext, () -> {
final String username = securityContext.getUser().principal();
final HasPrivilegesRequest privRequest = new HasPrivilegesRequest();
privRequest.username(username);
privRequest.clusterPrivileges(GetTrainedModelsAction.NAME);
privRequest.indexPrivileges(new RoleDescriptor.IndicesPrivileges[]{});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to set index and application privileges even though they're empty?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

privRequest.applicationPrivileges(new RoleDescriptor.ApplicationResourcePrivileges[]{});

ActionListener<HasPrivilegesResponse> privResponseListener = ActionListener.wrap(
r -> {
if (r.isCompleteMatch()) {
modelLoadAction.accept(client, listener);
} else {
listener.onFailure(Exceptions.authorizationError("user [" + username
+ "] does not have the privilege to get trained models so cannot use ml inference"));
}
},
listener::onFailure);

client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener);
});
} else {
modelLoadAction.accept(client, listener);
}
});
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ setup:
}
}
- match: { aggregations.good.buckets.0.regression_agg.value: 2.0 }

---
"Test pipeline agg referencing a single bucket":

Expand Down