implement batch document optimization for text embedding processor #1217

Open · wants to merge 1 commit into base: optimized-processor
CHANGELOG.md (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
- Add Optimized Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
- Optimizing embedding generation in text embedding processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
### Enhancements
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))
src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java
@@ -26,6 +26,8 @@
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.action.get.MultiGetItemResponse;
import org.opensearch.action.get.MultiGetRequest;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
@@ -42,6 +44,7 @@
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.optimization.InferenceFilter;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

/**
@@ -54,6 +57,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {

public static final String MODEL_ID_FIELD = "model_id";
public static final String FIELD_MAP_FIELD = "field_map";
public static final String INDEX_FIELD = "_index";
public static final String ID_FIELD = "_id";
private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (v1, v2) -> {
if (v1 instanceof Collection && v2 instanceof Collection) {
((Collection) v1).addAll((Collection) v2);
@@ -169,15 +174,59 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
*/
abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);

/**
* Filters the inference data of each document against its existing, already-indexed version before batch execution.
* @param ingestDocumentWrappers a list of ingest documents.
* @param existingDocuments a map from document ID to the existing document represented as a map.
* @param inferenceFilter the filter used to compare process maps and copy existing embeddings.
* @param handler a callback that accepts the processed list of ingest documents.
*/
protected void subBatchExecuteWithFilter(
List<IngestDocumentWrapper> ingestDocumentWrappers,
Map<String, Map<String, Object>> existingDocuments,
InferenceFilter inferenceFilter,
Consumer<List<IngestDocumentWrapper>> handler
) {
List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
List<DataForInference> filteredDataForInference = new ArrayList<>();
for (DataForInference dataForInference : dataForInferences) {
IngestDocumentWrapper ingestDocumentWrapper = dataForInference.getIngestDocumentWrapper();
Map<String, Object> processMap = dataForInference.getProcessMap();
Map<String, Object> document = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata();
Object id = document.get(ID_FIELD);
// keep dataForInference unfiltered when no existing document is found
if (Objects.isNull(id) || existingDocuments.containsKey(id.toString()) == false) {
filteredDataForInference.add(dataForInference);
continue;
}
// filter dataForInference when existing document exists
String docId = id.toString();
Map<String, Object> existingDocument = existingDocuments.get(docId);
Map<String, Object> filteredProcessMap = inferenceFilter.filter(existingDocument, document, processMap);
List<String> filteredInferenceList = createInferenceList(filteredProcessMap);
filteredDataForInference.add(new DataForInference(ingestDocumentWrapper, filteredProcessMap, filteredInferenceList));
}
List<String> filteredInferenceList = constructInferenceTexts(filteredDataForInference);
doSubBatchExecute(filteredDataForInference, filteredInferenceList, ingestDocumentWrappers, handler);
}
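For intuition, a minimal sketch (illustrative values and shapes; java.util imports assumed) of what this filtering achieves for a single document whose text has not changed:

// Illustrative only: the stored document already holds an embedding for the same text.
Map<String, Object> existingDocument = Map.of(
    "text", "hello world",
    "text_knn", List.of(0.1f, 0.2f)  // embedding written by a previous ingest
);
Map<String, Object> newDocument = new HashMap<>(Map.of("_id", "1", "text", "hello world"));
Map<String, Object> processMap = new HashMap<>(Map.of("text_knn", "hello world"));
// inferenceFilter.filter(existingDocument, newDocument, processMap) should copy the stored
// vector into newDocument and drop the unchanged text, so createInferenceList(...) returns
// an empty list and no model call is made for this document.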

@Override
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
handler.accept(Collections.emptyList());
return;
}

List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
List<String> inferenceList = constructInferenceTexts(dataForInferences);
doSubBatchExecute(dataForInferences, inferenceList, ingestDocumentWrappers, handler);
}

private void doSubBatchExecute(
List<DataForInference> dataForInferences,
List<String> inferenceList,
List<IngestDocumentWrapper> ingestDocumentWrappers,
Consumer<List<IngestDocumentWrapper>> handler
) {
if (inferenceList.isEmpty()) {
handler.accept(ingestDocumentWrappers);
return;
@@ -214,6 +263,19 @@ public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers,
});
}

@Override
public void batchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
super.batchExecute(ingestDocumentWrappers, handler);
}

protected List<List<IngestDocumentWrapper>> cutBatches(List<IngestDocumentWrapper> ingestDocumentWrappers) {
List<List<IngestDocumentWrapper>> batches = new ArrayList<>();
for (int i = 0; i < ingestDocumentWrappers.size(); i += this.batchSize) {
batches.add(ingestDocumentWrappers.subList(i, Math.min(i + this.batchSize, ingestDocumentWrappers.size())));
}
return batches;
}
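The helper above slices the list into subList views of at most batchSize documents each; a minimal standalone sketch of the same partitioning, with a generic element type standing in for IngestDocumentWrapper:

import java.util.ArrayList;
import java.util.List;

public class BatchCutExample {
    static <T> List<List<T>> cutBatches(List<T> items, int batchSize) {
        List<List<T>> batches = new ArrayList<>();
        for (int i = 0; i < items.size(); i += batchSize) {
            // subList returns a view; the final batch simply holds the remainder
            batches.add(items.subList(i, Math.min(i + batchSize, items.size())));
        }
        return batches;
    }

    public static void main(String[] args) {
        // prints [[1, 2, 3], [4, 5, 6], [7]]
        System.out.println(cutBatches(List.of(1, 2, 3, 4, 5, 6, 7), 3));
    }
}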

private Tuple<List<String>, Map<Integer, Integer>> sortByLengthAndReturnOriginalOrder(List<String> inferenceList) {
List<Tuple<Integer, String>> docsWithIndex = new ArrayList<>();
for (int i = 0; i < inferenceList.size(); ++i) {
@@ -415,6 +477,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
nlpResult.forEach(ingestDocument::setFieldValue);
}

/**
* Creates a MultiGetRequest from a list of ingest documents whose stored versions should be fetched for comparison.
* @param ingestDocumentWrappers list of ingest documents
*/
protected MultiGetRequest buildMultiGetRequest(List<IngestDocumentWrapper> ingestDocumentWrappers) {
MultiGetRequest multiGetRequest = new MultiGetRequest();
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
Object index = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(INDEX_FIELD);
Object id = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(ID_FIELD);
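// only documents carrying both _index and _id can have a stored counterpart to fetch;
// documents without an explicit ID (e.g. auto-generated IDs) are skipped here and will
// always go through inference in subBatchExecuteWithFilter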
if (Objects.nonNull(index) && Objects.nonNull(id)) {
multiGetRequest.add(index.toString(), id.toString());
}
}
return multiGetRequest;
}

/**
* Creates a map of existing documents from the MultiGet responses, keyed by document ID.
* @param multiGetItemResponses array of responses from the MultiGet request
*/
protected Map<String, Map<String, Object>> createDocumentMap(MultiGetItemResponse[] multiGetItemResponses) {
Map<String, Map<String, Object>> existingDocuments = new HashMap<>();
for (MultiGetItemResponse item : multiGetItemResponses) {
String id = item.getId();
Map<String, Object> existingDocument = item.getResponse().getSourceAsMap();
existingDocuments.put(id, existingDocument);
}
return existingDocuments;
}

@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java
@@ -4,19 +4,25 @@
*/
package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import org.opensearch.action.get.GetAction;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.MultiGetAction;
import org.opensearch.action.get.MultiGetItemResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import lombok.extern.log4j.Log4j2;
@@ -26,6 +32,7 @@
/**
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results.
* If the skip_existing flag is set, a Get/MultiGet request fetches the existing document so that embeddings for unchanged fields are reused instead of regenerated
*/
@Log4j2
public final class TextEmbeddingProcessor extends InferenceProcessor {
@@ -34,8 +41,6 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
public static final String SKIP_EXISTING = "skip_existing";
public static final boolean DEFAULT_SKIP_EXISTING = false;
private static final String INDEX_FIELD = "_index";
private static final String ID_FIELD = "_id";
private final OpenSearchClient openSearchClient;
private final boolean skipExisting;
private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;
@@ -106,4 +111,44 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
ActionListener.wrap(handler::accept, onException)
);
}
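The batchExecute override below only takes the MultiGet path when the processor was created with skip_existing enabled. For orientation, a hedged sketch of a processor configuration that turns it on ("model_id", "field_map", and "skip_existing" are the keys defined in this class; the concrete values are invented):

Map<String, Object> config = Map.of(
    "model_id", "my-embedding-model",
    "field_map", Map.of("description", "description_knn"),
    "skip_existing", true  // fetch existing docs and reuse embeddings for unchanged text
);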

@Override
public void batchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
// the skip_existing flag is turned off; call the default batchExecute without filtering
if (skipExisting == false) {
super.batchExecute(ingestDocumentWrappers, handler);
return;
}
if (ingestDocumentWrappers.isEmpty()) {
handler.accept(Collections.emptyList());
return;
}
// when skip_existing is turned on, the inference texts of each document are compared with the existing documents
// fetched via MultiGet, and TextEmbeddingInferenceFilter filters out texts whose embeddings can be reused
openSearchClient.execute(MultiGetAction.INSTANCE, buildMultiGetRequest(ingestDocumentWrappers), ActionListener.wrap(response -> {
MultiGetItemResponse[] multiGetItemResponses = response.getResponses();
if (multiGetItemResponses == null || multiGetItemResponses.length == 0) {
super.batchExecute(ingestDocumentWrappers, handler);
return;
}
Map<String, Map<String, Object>> existingDocuments = createDocumentMap(multiGetItemResponses);
if (this.batchSize >= ingestDocumentWrappers.size()) {
subBatchExecuteWithFilter(ingestDocumentWrappers, existingDocuments, textEmbeddingInferenceFilter, handler);
return;
}
List<List<IngestDocumentWrapper>> batches = cutBatches(ingestDocumentWrappers);
int size = ingestDocumentWrappers.size();
AtomicInteger counter = new AtomicInteger(size);
List<IngestDocumentWrapper> allResults = Collections.synchronizedList(new ArrayList<>());
for (List<IngestDocumentWrapper> batch : batches) {
this.subBatchExecuteWithFilter(batch, existingDocuments, textEmbeddingInferenceFilter, (batchResults) -> {
allResults.addAll(batchResults);
if (counter.addAndGet(-batchResults.size()) == 0) {
handler.accept(allResults);
}
assert counter.get() >= 0 : "counter is negative";
});
}
}, e -> { handler.accept(null); }));
}
}
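The override above fans sub-batches out through subBatchExecuteWithFilter and relies on a countdown over document counts so the caller's handler fires exactly once, after the last sub-batch completes. A standalone sketch of that aggregation pattern (synchronous here for brevity; in the processor each callback runs after asynchronous inference):

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

public class CountdownAggregationExample {
    public static void main(String[] args) {
        List<List<String>> batches = List.of(List.of("a", "b"), List.of("c"));
        int totalDocs = batches.stream().mapToInt(List::size).sum();
        AtomicInteger counter = new AtomicInteger(totalDocs);
        List<String> allResults = Collections.synchronizedList(new ArrayList<>());
        Consumer<List<String>> handler = results -> System.out.println("all done: " + results);
        for (List<String> batchResults : batches) {
            // each completed batch appends its results; whichever callback drives the
            // counter to zero delivers the combined results exactly once
            allResults.addAll(batchResults);
            if (counter.addAndGet(-batchResults.size()) == 0) {
                handler.accept(allResults);
            }
        }
    }
}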
src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java
@@ -13,6 +13,7 @@
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
@@ -71,7 +72,7 @@ public abstract Object filterInferenceValue(
* @return The processed value or null if embeddings are reused
*/

public abstract Object copyEmbeddingForSingleValue(
public abstract Object copyEmbedding(
String embeddingKey,
Object processValue,
Object existingValue,
@@ -80,27 +81,6 @@ public abstract Object copyEmbeddingForSingleValue(
int index
);

/**
* Abstract method to filter and compare lists of values.
* If all elements in the list are identical between the new and existing metadata maps, embeddings are copied,
* and an empty list is returned to indicate no further processing is required.
*
* @param embeddingKey The dot-notation path for the embedding field
* @param processList The list of values to be checked for potential embedding reuse
* @param existingList The list of existing values for comparison
* @param embeddingList The list of existing embeddings
* @param sourceAndMetadataMap The metadata map of the new document.
* @return A processed list or an empty list if embeddings are reused.
*/

public abstract List<Object> copyEmbeddingForMultipleValues(
String embeddingKey,
List<Object> processList,
List<Object> existingList,
List<Object> embeddingList,
Map<String, Object> sourceAndMetadataMap
);

/**
* This method navigates through the nested structure, checking each key-value pair recursively. It supports:
* Map values: Processed recursively using this method.
@@ -155,7 +135,7 @@ private Map<String, Object> filter(
);
filteredProcessMap.put(key, filteredInnerMap.isEmpty() ? null : filteredInnerMap);
} else if (value instanceof List) {
List<Object> processedList = filterListValue(
Object processedList = filterListValue(
currentPath,
ProcessorUtils.unsafeCastToObjectList(value),
sourceAndMetadataMap,
@@ -194,13 +174,15 @@ protected List<Object> filterListValue(
List<Object> existingListValue = ProcessorUtils.unsafeCastToObjectList(existingListOptional.get());
if (existingListValue.getFirst() instanceof List) {
// in case of nested list, compare and copy by list comparison
return copyEmbeddingForMultipleValues(
Object processedList = copyEmbedding(
embeddingKey,
processList,
ProcessorUtils.unsafeCastToObjectList(existingListValue.getFirst()),
ProcessorUtils.unsafeCastToObjectList(embeddingListOptional.get()),
sourceAndMetadataMap
existingListValue.getFirst(),
embeddingListOptional.get(),
sourceAndMetadataMap,
-1
);
return Objects.nonNull(processedList) ? ProcessorUtils.unsafeCastToObjectList(processedList) : null;
} else {
// in case of List of Maps, compare each map entry in list
return filterMapValuesInList(
@@ -231,20 +213,22 @@ public List<Object> filterMapValuesInList(
Map<String, Object> sourceAndMetadataMap
) {
List<Object> filteredList = new ArrayList<>();
ListIterator<Object> processListIterator = processList.listIterator();
ListIterator<Object> existingListIterator = existingList.listIterator();
ListIterator<Object> embeddingListIterator = embeddingList.listIterator();
int index = 0;
while (processListIterator.hasNext() && existingListIterator.hasNext() && embeddingListIterator.hasNext()) {
Object processedItem = copyEmbeddingForSingleValue(
embeddingKey,
processListIterator.next(),
existingListIterator.next(),
embeddingListIterator.next(),
sourceAndMetadataMap,
index++
);
filteredList.add(processedItem);
for (Object processValue : processList) {
if (Objects.nonNull(processValue) && existingListIterator.hasNext() && embeddingListIterator.hasNext()) {
Object processedItem = copyEmbedding(
embeddingKey,
processValue,
existingListIterator.next(),
embeddingListIterator.next(),
sourceAndMetadataMap,
index
);
filteredList.add(processedItem);
}
index++;
}
return filteredList;
}
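To make the unified contract concrete, here is an illustrative copyEmbedding implementation; this is not the actual TextEmbeddingInferenceFilter, and the dot-path resolution of embeddingKey plus the index bookkeeping (including the -1 convention used for whole-list copies above) are deliberately simplified:

import java.util.Map;

class ExampleInferenceFilter {
    Object copyEmbedding(
        String embeddingKey,
        Object processValue,
        Object existingValue,
        Object embeddingValue,
        Map<String, Object> sourceAndMetadataMap,
        int index
    ) {
        if (processValue != null && processValue.equals(existingValue)) {
            // unchanged value: reuse the stored embedding and return null so this
            // entry is dropped from the inference list
            sourceAndMetadataMap.put(embeddingKey, embeddingValue);
            return null;
        }
        // changed value: keep it so a fresh embedding is generated
        return processValue;
    }
}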