forked from elastic/elasticsearch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] add new trained_models/{model_id}/_infer endpoint for all superv…
…ised models and deprecate deployment infer api (elastic#86361) This commit adds a new `_ml/trained_models/{model_id}/_infer` API. This api works for both native NLP models and supervised models trained via Data Frame analytics. The format of the API is the same as the old `_ml/trained_models/{model_id}/deployment/_infer`. Taking a `docs` and an `inference_config` parameter. This PR also deprecates the old experimental `_ml/trained_models/{model_id}/deployment/_infer` API. The biggest difference is that the response now nests all results under an "inference_results" object. closes: elastic#86032
- Loading branch information
Showing
28 changed files
with
1,155 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 86361 | ||
summary: "Add new _infer endpoint for all supervised models and deprecate deployment infer api" | ||
area: Machine Learning | ||
type: enhancement | ||
issues: [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
271 changes: 271 additions & 0 deletions
271
docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,271 @@ | ||
[role="xpack"] | ||
[[infer-trained-model]] | ||
= Infer trained model API | ||
[subs="attributes"] | ||
++++ | ||
<titleabbrev>Infer trained model</titleabbrev> | ||
++++ | ||
|
||
Evaluates a trained model. The model may be any supervised model either trained by {dfanalytics} or imported. | ||
|
||
[[infer-trained-model-request]] | ||
== {api-request-title} | ||
|
||
`POST _ml/trained_models/<model_id>/_infer` | ||
|
||
//// | ||
[[infer-trained-model-prereq]] | ||
== {api-prereq-title} | ||
//// | ||
//// | ||
[[infer-trained-model-desc]] | ||
== {api-description-title} | ||
//// | ||
|
||
[[infer-trained-model-path-params]] | ||
== {api-path-parms-title} | ||
|
||
`<model_id>`:: | ||
(Required, string) | ||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] | ||
|
||
[[infer-trained-model-query-params]] | ||
== {api-query-parms-title} | ||
|
||
`timeout`:: | ||
(Optional, time) | ||
Controls the amount of time to wait for {infer} results. Defaults to 10 seconds. | ||
|
||
[[infer-trained-model-request-body]] | ||
== {api-request-body-title} | ||
|
||
`docs`:: | ||
(Required, array) | ||
An array of objects to pass to the model for inference. The objects should | ||
contain the fields matching your configured trained model input. Typically for NLP models, the field | ||
name is `text_field`. Currently for NLP models, only a single value is allowed. For {dfanalytics} or | ||
imported classification or regression models, more than one value is allowed. | ||
|
||
//// | ||
[[infer-trained-model-results]] | ||
== {api-response-body-title} | ||
//// | ||
//// | ||
[[ml-get-trained-models-response-codes]] | ||
== {api-response-codes-title} | ||
//// | ||
|
||
[[infer-trained-model-example]] | ||
== {api-examples-title} | ||
|
||
The response depends on the kind of model. | ||
|
||
For example, for language identification the response is the predicted language and the score: | ||
|
||
[source,console] | ||
-------------------------------------------------- | ||
POST _ml/trained_models/lang_ident_model_1/_infer | ||
{ | ||
"docs":[{"text": "The fool doth think he is wise, but the wise man knows himself to be a fool."}] | ||
} | ||
-------------------------------------------------- | ||
// TEST[skip:TBD] | ||
|
||
Here are the results predicting english with a high probability. | ||
|
||
[source,console-result] | ||
---- | ||
{ | ||
"inference_results": [ | ||
{ | ||
"predicted_value": "en", | ||
"prediction_probability": 0.9999658805366392, | ||
"prediction_score": 0.9999658805366392 | ||
} | ||
] | ||
} | ||
---- | ||
// NOTCONSOLE | ||
|
||
|
||
When it is a text classification model, the response is the score and predicted classification. | ||
|
||
For example: | ||
|
||
[source,console] | ||
-------------------------------------------------- | ||
POST _ml/trained_models/model2/_infer | ||
{ | ||
"docs": [{"text_field": "The movie was awesome!!"}] | ||
} | ||
-------------------------------------------------- | ||
// TEST[skip:TBD] | ||
|
||
The API returns the predicted label and the confidence. | ||
|
||
[source,console-result] | ||
---- | ||
{ | ||
"inference_results": [{ | ||
"predicted_value" : "POSITIVE", | ||
"prediction_probability" : 0.9998667964092964 | ||
}] | ||
} | ||
---- | ||
// NOTCONSOLE | ||
|
||
For named entity recognition (NER) models, the response contains the annotated | ||
text output and the recognized entities. | ||
|
||
[source,console] | ||
-------------------------------------------------- | ||
POST _ml/trained_models/model2/_infer | ||
{ | ||
"docs": [{"text_field": "Hi my name is Josh and I live in Berlin"}] | ||
} | ||
-------------------------------------------------- | ||
// TEST[skip:TBD] | ||
|
||
The API returns in this case: | ||
|
||
[source,console-result] | ||
---- | ||
{ | ||
"inference_results": [{ | ||
"predicted_value" : "Hi my name is [Josh](PER&Josh) and I live in [Berlin](LOC&Berlin)", | ||
"entities" : [ | ||
{ | ||
"entity" : "Josh", | ||
"class_name" : "PER", | ||
"class_probability" : 0.9977303419824, | ||
"start_pos" : 14, | ||
"end_pos" : 18 | ||
}, | ||
{ | ||
"entity" : "Berlin", | ||
"class_name" : "LOC", | ||
"class_probability" : 0.9992474323902818, | ||
"start_pos" : 33, | ||
"end_pos" : 39 | ||
} | ||
] | ||
}] | ||
} | ||
---- | ||
// NOTCONSOLE | ||
|
||
Zero-shot classification models require extra configuration defining the class labels. | ||
These labels are passed in the zero-shot inference config. | ||
|
||
[source,console] | ||
-------------------------------------------------- | ||
POST _ml/trained_models/model2/_infer | ||
{ | ||
"docs": [ | ||
{ | ||
"text_field": "This is a very happy person" | ||
} | ||
], | ||
"inference_config": { | ||
"zero_shot_classification": { | ||
"labels": [ | ||
"glad", | ||
"sad", | ||
"bad", | ||
"rad" | ||
], | ||
"multi_label": false | ||
} | ||
} | ||
} | ||
-------------------------------------------------- | ||
// TEST[skip:TBD] | ||
|
||
The API returns the predicted label and the confidence, as well as the top classes: | ||
|
||
[source,console-result] | ||
---- | ||
{ | ||
"inference_results": [{ | ||
"predicted_value" : "glad", | ||
"top_classes" : [ | ||
{ | ||
"class_name" : "glad", | ||
"class_probability" : 0.8061155063386439, | ||
"class_score" : 0.8061155063386439 | ||
}, | ||
{ | ||
"class_name" : "rad", | ||
"class_probability" : 0.18218006158387956, | ||
"class_score" : 0.18218006158387956 | ||
}, | ||
{ | ||
"class_name" : "bad", | ||
"class_probability" : 0.006325615787634201, | ||
"class_score" : 0.006325615787634201 | ||
}, | ||
{ | ||
"class_name" : "sad", | ||
"class_probability" : 0.0053788162898424545, | ||
"class_score" : 0.0053788162898424545 | ||
} | ||
], | ||
"prediction_probability" : 0.8061155063386439 | ||
}] | ||
} | ||
---- | ||
// NOTCONSOLE | ||
|
||
|
||
The tokenization truncate option can be overridden when calling the API: | ||
|
||
[source,console] | ||
-------------------------------------------------- | ||
POST _ml/trained_models/model2/_infer | ||
{ | ||
"docs": [{"text_field": "The Amazon rainforest covers most of the Amazon basin in South America"}], | ||
"inference_config": { | ||
"ner": { | ||
"tokenization": { | ||
"bert": { | ||
"truncate": "first" | ||
} | ||
} | ||
} | ||
} | ||
} | ||
-------------------------------------------------- | ||
// TEST[skip:TBD] | ||
|
||
When the input has been truncated due to the limit imposed by the model's `max_sequence_length` | ||
the `is_truncated` field appears in the response. | ||
|
||
[source,console-result] | ||
---- | ||
{ | ||
"inference_results": [{ | ||
"predicted_value" : "The [Amazon](LOC&Amazon) rainforest covers most of the [Amazon](LOC&Amazon) basin in [South America](LOC&South+America)", | ||
"entities" : [ | ||
{ | ||
"entity" : "Amazon", | ||
"class_name" : "LOC", | ||
"class_probability" : 0.9505460915724254, | ||
"start_pos" : 4, | ||
"end_pos" : 10 | ||
}, | ||
{ | ||
"entity" : "Amazon", | ||
"class_name" : "LOC", | ||
"class_probability" : 0.9969992804311777, | ||
"start_pos" : 41, | ||
"end_pos" : 47 | ||
} | ||
], | ||
"is_truncated" : true | ||
}] | ||
} | ||
---- | ||
// NOTCONSOLE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.