Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Add hyperparameter model metadata #66349

Merged
merged 21 commits into from
Dec 18, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
d575513
init commit
valeriy42 Dec 14, 2020
e244d66
continue adjusting tests
valeriy42 Dec 15, 2020
f6236fb
fix ChunkedTrainedModelPersisterIT.java
valeriy42 Dec 15, 2020
48a3645
formatting
valeriy42 Dec 16, 2020
d28b824
rename hyperparameters, add supplied
valeriy42 Dec 16, 2020
ff27fab
code review
valeriy42 Dec 18, 2020
0c8e00b
Update x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/data…
valeriy42 Dec 18, 2020
f73ec7f
Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/…
valeriy42 Dec 18, 2020
f301ac7
Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/…
valeriy42 Dec 18, 2020
5515700
optional fields
valeriy42 Dec 18, 2020
b4de908
Merge branch 'hyperparameter-model-metadata' of https://github.com/va…
valeriy42 Dec 18, 2020
33b8af7
handle optional importance fields
valeriy42 Dec 18, 2020
f14592b
formatting
valeriy42 Dec 18, 2020
0399e10
Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/…
valeriy42 Dec 18, 2020
8be327d
Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/…
valeriy42 Dec 18, 2020
be37937
Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/…
valeriy42 Dec 18, 2020
f9e6ef2
Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/…
valeriy42 Dec 18, 2020
62dab59
Update x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/…
valeriy42 Dec 18, 2020
3a5182d
Update x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/…
valeriy42 Dec 18, 2020
cbf0dc8
fix dangling comma
valeriy42 Dec 18, 2020
43dbb5b
Merge branch 'master' into hyperparameter-model-metadata
elasticmachine Dec 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,15 @@ public static class Includes implements Writeable {
static final String DEFINITION = "definition";
static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
static final String HYPERPARAMETERS = "hyperparameters";

private static final Set<String> KNOWN_INCLUDES;
static {
HashSet<String> includes = new HashSet<>(3, 1.0f);
HashSet<String> includes = new HashSet<>(4, 1.0f);
includes.add(DEFINITION);
includes.add(TOTAL_FEATURE_IMPORTANCE);
includes.add(FEATURE_IMPORTANCE_BASELINE);
includes.add(HYPERPARAMETERS);
KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
}

Expand Down Expand Up @@ -94,6 +97,10 @@ public boolean isIncludeFeatureImportanceBaseline() {
return this.includes.contains(FEATURE_IMPORTANCE_BASELINE);
}

public boolean isIncludeHyperparameters() {
return this.includes.contains(HYPERPARAMETERS);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaseline;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.Hyperparameters;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
Expand Down Expand Up @@ -57,6 +58,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
public static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
public static final String HYPERPARAMETERS = "hyperparameters";
private static final Set<String> RESERVED_METADATA_FIELDS = new HashSet<>(Arrays.asList(
TOTAL_FEATURE_IMPORTANCE,
FEATURE_IMPORTANCE_BASELINE));
Expand Down Expand Up @@ -492,6 +494,18 @@ public Builder setBaselineFeatureImportance(FeatureImportanceBaseline featureImp
return this;
}

public Builder setHyperparameters(List<Hyperparameters> hyperparameters) {
if (hyperparameters == null) {
return this;
}
if (this.metadata == null) {
this.metadata = new HashMap<>();
}
this.metadata.put(HYPERPARAMETERS,
hyperparameters.stream().map(Hyperparameters::asMap).collect(Collectors.toList()));
return this;
}

public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
if (definition == null) {
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class Hyperparameters implements ToXContentObject, Writeable {

private static final String NAME = "hyperparameters";
public static final ParseField HYPERPARAMETER_NAME = new ParseField("name");
public static final ParseField VALUE = new ParseField("value");
public static final ParseField ABSOLUTE_IMPORTANCE = new ParseField("absolute_importance");
public static final ParseField RELATIVE_IMPORTANCE = new ParseField("relative_importance");
public static final ParseField SUPPLIED = new ParseField("supplied");


// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ConstructingObjectParser<Hyperparameters, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<Hyperparameters, Void> STRICT_PARSER = createParser(false);

@SuppressWarnings("unchecked")
private static ConstructingObjectParser<Hyperparameters, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<Hyperparameters, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new Hyperparameters((String)a[0], (Double)a[1], (Double)a[2], (Double)a[3], (Boolean)a[4]));
parser.declareString(ConstructingObjectParser.constructorArg(), HYPERPARAMETER_NAME);
parser.declareDouble(ConstructingObjectParser.constructorArg(), VALUE);
parser.declareDouble(ConstructingObjectParser.constructorArg(), ABSOLUTE_IMPORTANCE);
parser.declareDouble(ConstructingObjectParser.constructorArg(), RELATIVE_IMPORTANCE);
parser.declareBoolean(ConstructingObjectParser.constructorArg(), SUPPLIED);
return parser;
}

public static Hyperparameters fromXContent(XContentParser parser, boolean lenient) throws IOException {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}

public final String hyperparameterName;
public final Double value;
public final Double absoluteImportance;
public final Double relativeImportance;
public final Boolean supplied;

public Hyperparameters(StreamInput in) throws IOException {
this.hyperparameterName = in.readString();
this.value = in.readDouble();
this.absoluteImportance = in.readDouble();
this.relativeImportance = in.readDouble();
this.supplied = in.readBoolean();
}

Hyperparameters(String hyperparameterName, Double value, Double absoluteImportance, Double relativeImportance, Boolean supplied) {
this.hyperparameterName = hyperparameterName;
this.value = value;
this.absoluteImportance = absoluteImportance;
this.relativeImportance = relativeImportance;
this.supplied = supplied;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the Doubles and the Boolean are truly not nullable, they should be double and boolean (unboxed, not nullable).

If they are nullable, the parser needs to treat them as optional and the serialization needs to as well

}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(hyperparameterName);
out.writeDouble(value);
out.writeDouble(absoluteImportance);
out.writeDouble(relativeImportance);
out.writeBoolean(supplied);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.map(asMap());
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Hyperparameters that = (Hyperparameters) o;
return Objects.equals(that.hyperparameterName, hyperparameterName)
&& Objects.equals(value, that.value)
&& Objects.equals(absoluteImportance, that.absoluteImportance)
&& Objects.equals(relativeImportance, that.relativeImportance)
&& Objects.equals(supplied, that.supplied)
;
}

public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(HYPERPARAMETER_NAME.getPreferredName(), hyperparameterName);
map.put(VALUE.getPreferredName(), value);
map.put(ABSOLUTE_IMPORTANCE.getPreferredName(), absoluteImportance);
map.put(RELATIVE_IMPORTANCE.getPreferredName(), relativeImportance);
map.put(SUPPLIED.getPreferredName(), supplied);

return map;
}

@Override
public int hashCode() {
return Objects.hash(hyperparameterName, value, absoluteImportance, relativeImportance, supplied);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {

public static final String NAME = "trained_model_metadata";
public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance");
public static final ParseField HYPERPARAMETERS = new ParseField("hyperparameters");
public static final ParseField FEATURE_IMPORTANCE_BASELINE = new ParseField("feature_importance_baseline");
public static final ParseField MODEL_ID = new ParseField("model_id");

Expand All @@ -38,14 +39,18 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
private static ConstructingObjectParser<TrainedModelMetadata, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<TrainedModelMetadata, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new TrainedModelMetadata((String)a[0], (List<TotalFeatureImportance>)a[1], (FeatureImportanceBaseline)a[2]));
a -> new TrainedModelMetadata((String)a[0], (List<TotalFeatureImportance>)a[1], (FeatureImportanceBaseline)a[2],
(List<Hyperparameters>)a[3]));
parser.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
parser.declareObjectArray(ConstructingObjectParser.constructorArg(),
ignoreUnknownFields ? TotalFeatureImportance.LENIENT_PARSER : TotalFeatureImportance.STRICT_PARSER,
TOTAL_FEATURE_IMPORTANCE);
parser.declareObject(ConstructingObjectParser.optionalConstructorArg(),
ignoreUnknownFields ? FeatureImportanceBaseline.LENIENT_PARSER : FeatureImportanceBaseline.STRICT_PARSER,
FEATURE_IMPORTANCE_BASELINE);
parser.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(),
ignoreUnknownFields ? Hyperparameters.LENIENT_PARSER : Hyperparameters.STRICT_PARSER,
HYPERPARAMETERS);
return parser;
}

Expand All @@ -63,20 +68,25 @@ public static String modelId(String docId) {

private final List<TotalFeatureImportance> totalFeatureImportances;
private final FeatureImportanceBaseline featureImportanceBaselines;
private final List<Hyperparameters> hyperparameters;
private final String modelId;

public TrainedModelMetadata(StreamInput in) throws IOException {
this.modelId = in.readString();
this.totalFeatureImportances = in.readList(TotalFeatureImportance::new);
this.featureImportanceBaselines = in.readOptionalWriteable(FeatureImportanceBaseline::new);
this.hyperparameters = in.readList(Hyperparameters::new);
}

public TrainedModelMetadata(String modelId,
List<TotalFeatureImportance> totalFeatureImportances,
FeatureImportanceBaseline featureImportanceBaselines) {
FeatureImportanceBaseline featureImportanceBaselines,
List<Hyperparameters> hyperparameters) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.totalFeatureImportances = Collections.unmodifiableList(totalFeatureImportances);
this.featureImportanceBaselines = featureImportanceBaselines;
this.hyperparameters = hyperparameters == null ? Collections.emptyList() : Collections.unmodifiableList(hyperparameters);

valeriy42 marked this conversation as resolved.
Show resolved Hide resolved
}

public String getModelId() {
Expand All @@ -95,26 +105,33 @@ public FeatureImportanceBaseline getFeatureImportanceBaselines() {
return featureImportanceBaselines;
}

public List<Hyperparameters> getHyperparameters() {
return hyperparameters;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelMetadata that = (TrainedModelMetadata) o;
return Objects.equals(totalFeatureImportances, that.totalFeatureImportances) &&
Objects.equals(featureImportanceBaselines, that.featureImportanceBaselines) &&
Objects.equals(hyperparameters, that.hyperparameters) &&
Objects.equals(modelId, that.modelId);
}

@Override
public int hashCode() {
return Objects.hash(totalFeatureImportances, featureImportanceBaselines, modelId);
return Objects.hash(totalFeatureImportances, featureImportanceBaselines, hyperparameters, modelId);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeList(totalFeatureImportances);
out.writeOptionalWriteable(featureImportanceBaselines);
out.writeList(hyperparameters);

valeriy42 marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
Expand All @@ -128,6 +145,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (featureImportanceBaselines != null) {
builder.field(FEATURE_IMPORTANCE_BASELINE.getPreferredName(), featureImportanceBaselines);
}
builder.field(HYPERPARAMETERS.getPreferredName(), hyperparameters);
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,27 @@
}
}
}
},
"hyperparameters": {
"type": "nested",
"dynamic": "false",
"properties": {
"name": {
"type": "keyword"
},
"value": {
"type": "double"
},
"absolute_importance": {
"type": "double"
},
"relative_importance": {
"type": "double"
},
"supplied": {
"type": "boolean"
},

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ dangling comma

}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ protected Request createTestInstance() {
randomBoolean() ? null :
Stream.generate(() -> randomFrom(Includes.DEFINITION,
Includes.TOTAL_FEATURE_IMPORTANCE,
Includes.FEATURE_IMPORTANCE_BASELINE))
Includes.FEATURE_IMPORTANCE_BASELINE,
Includes.HYPERPARAMETERS))
.limit(4)
.collect(Collectors.toSet()));
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;

import java.io.IOException;


public class HyperparametersTests extends AbstractBWCSerializationTestCase<Hyperparameters> {

private boolean lenient;

@SuppressWarnings("unchecked")
public static Hyperparameters randomInstance() {
return new Hyperparameters(
valeriy42 marked this conversation as resolved.
Show resolved Hide resolved
randomAlphaOfLength(10),
randomDoubleBetween(0.0, 1.0, true),
randomDoubleBetween(0.0, 100.0, true),
randomDoubleBetween(0.0, 1.0, true),
randomBoolean());
valeriy42 marked this conversation as resolved.
Show resolved Hide resolved
}

@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}

@Override
protected Hyperparameters createTestInstance() {
return randomInstance();
}

@Override
protected Writeable.Reader<Hyperparameters> instanceReader() {
return Hyperparameters::new;
}

@Override
protected Hyperparameters doParseInstance(XContentParser parser) throws IOException {
return Hyperparameters.fromXContent(parser, lenient);
}

@Override
protected boolean supportsUnknownFields() {
return lenient;
}

@Override
protected Hyperparameters mutateInstanceForVersion(Hyperparameters instance, Version version) {
return instance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public static TrainedModelMetadata randomInstance() {
return new TrainedModelMetadata(
randomAlphaOfLength(10),
Stream.generate(TotalFeatureImportanceTests::randomInstance).limit(randomIntBetween(1, 10)).collect(Collectors.toList()),
randomBoolean() ? null : FeatureImportanceBaselineTests.randomInstance());
randomBoolean() ? null : FeatureImportanceBaselineTests.randomInstance(),
randomBoolean() ? null : Stream.generate(HyperparametersTests::randomInstance).limit(randomIntBetween(1, 10)).collect(Collectors.toList()));
}

@Before
Expand Down
Loading