
Adding chunking settings to GoogleVertexAiService, AzureAiStudioService, and AlibabaCloudSearchService #113981

Merged
6 changes: 6 additions & 0 deletions docs/changelog/113981.yaml
@@ -0,0 +1,6 @@
pr: 113981
summary: "Adding chunking settings to `GoogleVertexAiService,` `AzureAiStudioService,`\
\ and `AlibabaCloudSearchService`"
area: Machine Learning
type: enhancement
issues: []
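
The same parsing pattern is applied in each of the touched services: when the chunking-settings feature flag is enabled and the task type produces embeddings, an optional chunking_settings block is read from the config and built into a ChunkingSettings instance; otherwise it stays null and behaviour is unchanged. Below is a minimal sketch of that pattern, using the helpers that appear in the diff (ChunkingSettingsFeatureFlag, ChunkingSettingsBuilder, and the services' existing removeFromMapOrDefaultEmpty utility). The wrapper method itself is illustrative, not code from the PR.

// Sketch of the pattern each service repeats in parseRequestConfig / parsePersistedConfig*.
// Assumes the imports added by this PR (ChunkingSettings, ChunkingSettingsFeatureFlag,
// ChunkingSettingsBuilder) plus java.util.List/Map; removeFromMapOrDefaultEmpty is the
// existing service utility used throughout the diff.
static ChunkingSettings extractOptionalChunkingSettings(Map<String, Object> config, TaskType taskType) {
    // Only embedding task types chunk their input, so only they accept chunking settings.
    if (ChunkingSettingsFeatureFlag.isEnabled()
        && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
        // An absent chunking_settings block becomes an empty map; the builder handles that case.
        return ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
    }
    return null;
}

The set of task types checked differs per service: AlibabaCloudSearchService accepts chunking settings for both text and sparse embeddings, while AzureAiStudioService only checks TaskType.TEXT_EMBEDDING.
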
@@ -15,6 +15,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -24,6 +25,8 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -74,11 +77,19 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}

AlibabaCloudSearchModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
@@ -99,6 +110,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage
) {
@@ -107,6 +119,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
taskType,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
failureMessage,
ConfigurationParseContext.PERSISTENT
@@ -118,6 +131,7 @@ private static AlibabaCloudSearchModel createModel(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
@@ -129,6 +143,7 @@ private static AlibabaCloudSearchModel createModel(
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
@@ -138,6 +153,7 @@ private static AlibabaCloudSearchModel createModel(
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
@@ -174,11 +190,17 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
@@ -189,11 +211,17 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
@@ -238,11 +266,22 @@ protected void doChunkedInfer(
AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());

var batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT,
alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
} else {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
}

for (var request : batchedRequests) {
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());

Review comment on the EmbeddingRequestChunker.EmbeddingType.FLOAT argument (Member):
Alibaba supports sparse embeddings. If the task type is sparse then the embedding type should also be sparse.

Reply (Member, Author):
Good catch, I've updated this and also noticed that we were missing some overrides in the text embedding service settings for Alibaba which were preventing us from using it for semantic text fields. I've updated these both now.
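
A sketch of the change the reviewer is asking for: derive the chunker's embedding type from the model's task type instead of hard-coding FLOAT. This assumes EmbeddingRequestChunker.EmbeddingType exposes a SPARSE constant alongside FLOAT and that the model reports its task type via getTaskType(); the merged code may achieve the same effect with a different helper.

// Hypothetical helper: choose the embedding type the chunker should produce
// based on the task type, so sparse-embedding models get sparse chunked results.
private static EmbeddingRequestChunker.EmbeddingType embeddingTypeForTask(TaskType taskType) {
    return switch (taskType) {
        case SPARSE_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.SPARSE;
        case TEXT_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.FLOAT;
        default -> throw new IllegalArgumentException("chunked inference is not supported for task type [" + taskType + "]");
    };
}

With a helper like this, both branches of the feature-flag check above would pass embeddingTypeForTask(alibabaCloudSearchModel.getTaskType()) in place of EmbeddingRequestChunker.EmbeddingType.FLOAT.
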
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
@@ -39,6 +40,7 @@ public AlibabaCloudSearchEmbeddingsModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
@@ -48,6 +50,7 @@ public AlibabaCloudSearchEmbeddingsModel(
service,
AlibabaCloudSearchEmbeddingsServiceSettings.fromMap(serviceSettings, context),
AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(taskSettings),
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}
@@ -59,10 +62,11 @@ public AlibabaCloudSearchEmbeddingsModel(
String service,
AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings,
AlibabaCloudSearchEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
serviceSettings.getCommonSettings()
);
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
@@ -39,6 +40,7 @@ public AlibabaCloudSearchSparseModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
@@ -48,6 +50,7 @@ public AlibabaCloudSearchSparseModel(
service,
AlibabaCloudSearchSparseServiceSettings.fromMap(serviceSettings, context),
AlibabaCloudSearchSparseTaskSettings.fromMap(taskSettings),
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}
@@ -59,10 +62,11 @@ public AlibabaCloudSearchSparseModel(
String service,
AlibabaCloudSearchSparseServiceSettings serviceSettings,
AlibabaCloudSearchSparseTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
serviceSettings.getCommonSettings()
);
@@ -16,6 +16,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
@@ -24,6 +25,8 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -90,11 +93,23 @@ protected void doChunkedInfer(
) {
if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) {
var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
var batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT,
baseAzureAiStudioModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
} else {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
}

for (var request : batchedRequests) {
var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
@@ -115,11 +130,19 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}

AzureAiStudioModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
@@ -146,11 +169,17 @@ public AzureAiStudioModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
@@ -161,11 +190,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
@@ -186,6 +221,7 @@ private static AzureAiStudioModel createModel(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
@@ -198,6 +234,7 @@ private static AzureAiStudioModel createModel(
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
@@ -235,6 +272,7 @@ private AzureAiStudioModel createModelFromPersistent(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secretSettings,
String failureMessage
) {
@@ -243,6 +281,7 @@ private AzureAiStudioModel createModelFromPersistent(
taskType,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
failureMessage,
ConfigurationParseContext.PERSISTENT
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.azureaistudio.embeddings;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
@@ -44,9 +45,13 @@ public AzureAiStudioEmbeddingsModel(
String service,
AzureAiStudioEmbeddingsServiceSettings serviceSettings,
AzureAiStudioEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
DefaultSecretSettings secrets
) {
super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets));
super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secrets)
);
}

public AzureAiStudioEmbeddingsModel(
@@ -55,6 +60,7 @@ public AzureAiStudioEmbeddingsModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
@@ -64,6 +70,7 @@ public AzureAiStudioEmbeddingsModel(
service,
AzureAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context),
AzureAiStudioEmbeddingsTaskSettings.fromMap(taskSettings),
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}