Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding chunking settings to GoogleVertexAiService, AzureAiStudioService, and AlibabaCloudSearchService #113981

Merged
6 changes: 6 additions & 0 deletions docs/changelog/113981.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 113981
summary: "Adding chunking settings to `GoogleVertexAiService,` `AzureAiStudioService,`\
\ and `AlibabaCloudSearchService`"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
Expand All @@ -24,6 +25,8 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
Expand Down Expand Up @@ -74,11 +77,19 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}

AlibabaCloudSearchModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
Expand All @@ -99,6 +110,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage
) {
Expand All @@ -107,6 +119,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
taskType,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
failureMessage,
ConfigurationParseContext.PERSISTENT
Expand All @@ -118,6 +131,7 @@ private static AlibabaCloudSearchModel createModel(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand All @@ -129,6 +143,7 @@ private static AlibabaCloudSearchModel createModel(
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
Expand All @@ -138,6 +153,7 @@ private static AlibabaCloudSearchModel createModel(
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
Expand Down Expand Up @@ -174,11 +190,17 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
Expand All @@ -189,11 +211,17 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
Expand Down Expand Up @@ -238,17 +266,36 @@ protected void doChunkedInfer(
AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());

var batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType()),
alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
} else {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType())
).batchRequestsWithListeners(listener);
}

for (var request : batchedRequests) {
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
}
}

private EmbeddingRequestChunker.EmbeddingType getEmbeddingTypeFromTaskType(TaskType taskType) {
return switch (taskType) {
case TaskType.TEXT_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.FLOAT;
case TaskType.SPARSE_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.SPARSE;
default -> throw new IllegalArgumentException("Unsupported task type for chunking: " + taskType);
};
}

/**
* For text embedding models get the embedding size and
* update the service settings.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand Down Expand Up @@ -39,6 +40,7 @@ public AlibabaCloudSearchEmbeddingsModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
Expand All @@ -48,6 +50,7 @@ public AlibabaCloudSearchEmbeddingsModel(
service,
AlibabaCloudSearchEmbeddingsServiceSettings.fromMap(serviceSettings, context),
AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(taskSettings),
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}
Expand All @@ -59,10 +62,11 @@ public AlibabaCloudSearchEmbeddingsModel(
String service,
AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings,
AlibabaCloudSearchEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
serviceSettings.getCommonSettings()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
Expand Down Expand Up @@ -81,10 +82,21 @@ public SimilarityMeasure getSimilarity() {
return similarity;
}

public Integer getDimensions() {
@Override
public Integer dimensions() {
return dimensions;
}

@Override
public SimilarityMeasure similarity() {
return similarity;
}

@Override
public DenseVectorFieldMapper.ElementType elementType() {
return DenseVectorFieldMapper.ElementType.FLOAT;
}

public Integer getMaxInputTokens() {
return maxInputTokens;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand Down Expand Up @@ -39,6 +40,7 @@ public AlibabaCloudSearchSparseModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
Expand All @@ -48,6 +50,7 @@ public AlibabaCloudSearchSparseModel(
service,
AlibabaCloudSearchSparseServiceSettings.fromMap(serviceSettings, context),
AlibabaCloudSearchSparseTaskSettings.fromMap(taskSettings),
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}
Expand All @@ -59,10 +62,11 @@ public AlibabaCloudSearchSparseModel(
String service,
AlibabaCloudSearchSparseServiceSettings serviceSettings,
AlibabaCloudSearchSparseTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
serviceSettings.getCommonSettings()
);
Expand Down
Loading