Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Enable built-in Inference Endpoints and default for Semantic Text #116931

Merged
merged 5 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/116931.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 116931
summary: Enable built-in Inference Endpoints and default for Semantic Text
area: "Machine Learning"
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ public enum FeatureFlag {
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
INFERENCE_DEFAULT_ELSER("es.inference_default_elser_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
ML_SCALE_FROM_ZERO("es.ml_scale_from_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null);

public final String systemProperty;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ public void tearDown() throws Exception {

@SuppressWarnings("unchecked")
public void testInferDeploysDefaultElser() throws IOException {
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
assertDefaultElserConfig(model);

Expand Down Expand Up @@ -78,7 +77,6 @@ private static void assertDefaultElserConfig(Map<String, Object> modelConfig) {

@SuppressWarnings("unchecked")
public void testInferDeploysDefaultE5() throws IOException {
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
var model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
assertDefaultE5Config(model);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ public void testCRUD() throws IOException {
}

var getAllModels = getAllModels();
int numModels = DefaultElserFeatureFlag.isEnabled() ? 11 : 9;
int numModels = 11;
assertThat(getAllModels, hasSize(numModels));

var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
int numSparseModels = DefaultElserFeatureFlag.isEnabled() ? 6 : 5;
int numSparseModels = 6;
assertThat(getSparseModels, hasSize(numSparseModels));
for (var sparseModel : getSparseModels) {
assertEquals("sparse_embedding", sparseModel.get("task_type"));
}

var getDenseModels = getModels("_all", TaskType.TEXT_EMBEDDING);
int numDenseModels = DefaultElserFeatureFlag.isEnabled() ? 5 : 4;
int numDenseModels = 5;
assertThat(getDenseModels, hasSize(numDenseModels));
for (var denseModel : getDenseModels) {
assertEquals("text_embedding", denseModel.get("task_type"));
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;

import java.util.HashSet;
import java.util.Set;

/**
Expand All @@ -24,16 +23,14 @@ public class InferenceFeatures implements FeatureSpecification {

@Override
public Set<NodeFeature> getFeatures() {
var features = new HashSet<NodeFeature>();
features.add(TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED);
features.add(RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED);
features.add(SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID);
features.add(SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS);
features.add(TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED);
if (DefaultElserFeatureFlag.isEnabled()) {
features.add(SemanticTextFieldMapper.SEMANTIC_TEXT_DEFAULT_ELSER_2);
}
return Set.copyOf(features);
return Set.of(
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED,
RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED,
SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID,
SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS,
SemanticTextFieldMapper.SEMANTIC_TEXT_DEFAULT_ELSER_2,
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,8 @@ public Collection<?> createComponents(PluginServices services) {
// reference correctly
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
registry.init(services.client());
if (DefaultElserFeatureFlag.isEnabled()) {
for (var service : registry.getServices().values()) {
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
}
for (var service : registry.getServices().values()) {
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
}
inferenceServiceRegistry.set(registry);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -111,16 +110,12 @@ public static class Builder extends FieldMapper.Builder {
INFERENCE_ID_FIELD,
false,
mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId,
DefaultElserFeatureFlag.isEnabled() ? DEFAULT_ELSER_2_INFERENCE_ID : null
DEFAULT_ELSER_2_INFERENCE_ID
).addValidator(v -> {
if (Strings.isEmpty(v)) {
// If the default ELSER feature flag is enabled, the only way we get here is if the user explicitly sets the param to an
// empty value. However, if the feature flag is disabled, we can get here if the user didn't set the param.
// Adjust the error message appropriately.
String message = DefaultElserFeatureFlag.isEnabled()
? "[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must not be empty"
: "[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must be specified";
throw new IllegalArgumentException(message);
throw new IllegalArgumentException(
"[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must not be empty"
);
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -69,11 +66,6 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient

@Override
public Set<String> supportedCapabilities() {
Set<String> capabilities = new HashSet<>();
if (DefaultElserFeatureFlag.isEnabled()) {
capabilities.add(DEFAULT_ELSER_2_CAPABILITY);
}

return Collections.unmodifiableSet(capabilities);
return Set.of(DEFAULT_ELSER_2_CAPABILITY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
import org.elasticsearch.xpack.inference.InferencePlugin;

import java.io.IOException;
Expand Down Expand Up @@ -296,11 +295,6 @@ protected void maybeStartDeployment(
InferModelAction.Request request,
ActionListener<InferModelAction.Response> listener
) {
if (DefaultElserFeatureFlag.isEnabled() == false) {
listener.onFailure(e);
return;
}

if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> {
client.execute(InferModelAction.INSTANCE, request, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.junit.AssumptionViolatedException;
Expand Down Expand Up @@ -103,9 +102,6 @@ protected Collection<? extends Plugin> getPlugins() {
@Override
protected void minimalMapping(XContentBuilder b) throws IOException {
b.field("type", "semantic_text");
if (DefaultElserFeatureFlag.isEnabled() == false) {
b.field("inference_id", "test_model");
}
}

@Override
Expand Down Expand Up @@ -175,9 +171,7 @@ public void testDefaults() throws Exception {
DocumentMapper mapper = mapperService.documentMapper();
assertEquals(Strings.toString(fieldMapping), mapper.mappingSource().toString());
assertSemanticTextField(mapperService, fieldName, false);
if (DefaultElserFeatureFlag.isEnabled()) {
assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, DEFAULT_ELSER_2_INFERENCE_ID);
}
assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, DEFAULT_ELSER_2_INFERENCE_ID);

ParsedDocument doc1 = mapper.parse(source(this::writeField));
List<IndexableField> fields = doc1.rootDoc().getFields("field");
Expand Down Expand Up @@ -211,15 +205,13 @@ public void testSetInferenceEndpoints() throws IOException {
assertSerialization.accept(fieldMapping, mapperService);
}
{
if (DefaultElserFeatureFlag.isEnabled()) {
final XContentBuilder fieldMapping = fieldMapping(
b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId)
);
final MapperService mapperService = createMapperService(fieldMapping);
assertSemanticTextField(mapperService, fieldName, false);
assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, searchInferenceId);
assertSerialization.accept(fieldMapping, mapperService);
}
final XContentBuilder fieldMapping = fieldMapping(
b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId)
);
final MapperService mapperService = createMapperService(fieldMapping);
assertSemanticTextField(mapperService, fieldName, false);
assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, searchInferenceId);
assertSerialization.accept(fieldMapping, mapperService);
}
{
final XContentBuilder fieldMapping = fieldMapping(
Expand All @@ -246,26 +238,18 @@ public void testInvalidInferenceEndpoints() {
);
}
{
final String expectedMessage = DefaultElserFeatureFlag.isEnabled()
? "[inference_id] on mapper [field] of type [semantic_text] must not be empty"
: "[inference_id] on mapper [field] of type [semantic_text] must be specified";
Exception e = expectThrows(
MapperParsingException.class,
() -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, "")))
);
assertThat(e.getMessage(), containsString(expectedMessage));
assertThat(e.getMessage(), containsString("[inference_id] on mapper [field] of type [semantic_text] must not be empty"));
}
{
if (DefaultElserFeatureFlag.isEnabled()) {
Exception e = expectThrows(
MapperParsingException.class,
() -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, "")))
);
assertThat(
e.getMessage(),
containsString("[search_inference_id] on mapper [field] of type [semantic_text] must not be empty")
);
}
Exception e = expectThrows(
MapperParsingException.class,
() -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, "")))
);
assertThat(e.getMessage(), containsString("[search_inference_id] on mapper [field] of type [semantic_text] must not be empty"));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ setup:
---
"Calculates embeddings using the default ELSER 2 endpoint":
- requires:
reason: "default ELSER 2 inference ID is behind a feature flag"
reason: "default ELSER 2 inference ID is enabled via a capability"
test_runner_features: [capabilities]
capabilities:
- method: GET
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ setup:
---
"Query a field that uses the default ELSER 2 endpoint":
- requires:
reason: "default ELSER 2 inference ID is behind a feature flag"
reason: "default ELSER 2 inference ID is enabled via a capability"
test_runner_features: [capabilities]
capabilities:
- method: GET
Expand Down