Skip to content

Commit a257a0c

Browse files
committed
implement batch document update scenario for text embedding processor
Signed-off-by: will-hwang <sang7239@gmail.com>
1 parent 1a6e58e commit a257a0c

10 files changed

+384
-92
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55

66
## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
77
### Features
8-
- Add Optimized Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
8+
- Optimizing embedding generation in text embedding processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
99
### Enhancements
1010
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
1111
- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java

+47-4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.apache.commons.lang3.StringUtils;
2727
import org.apache.commons.lang3.tuple.ImmutablePair;
2828
import org.apache.commons.lang3.tuple.Pair;
29+
import org.opensearch.action.get.MultiGetItemResponse;
30+
import org.opensearch.action.get.MultiGetRequest;
2931
import org.opensearch.common.collect.Tuple;
3032
import org.opensearch.core.action.ActionListener;
3133
import org.opensearch.core.common.util.CollectionUtils;
@@ -54,6 +56,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
5456

5557
public static final String MODEL_ID_FIELD = "model_id";
5658
public static final String FIELD_MAP_FIELD = "field_map";
59+
public static final String INDEX_FIELD = "_index";
60+
public static final String ID_FIELD = "_id";
5761
private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (v1, v2) -> {
5862
if (v1 instanceof Collection && v2 instanceof Collection) {
5963
((Collection) v1).addAll((Collection) v2);
@@ -182,6 +186,15 @@ public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers,
182186
handler.accept(ingestDocumentWrappers);
183187
return;
184188
}
189+
doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
190+
}
191+
192+
protected void doSubBatchExecute(
193+
List<IngestDocumentWrapper> ingestDocumentWrappers,
194+
List<String> inferenceList,
195+
List<DataForInference> dataForInferences,
196+
Consumer<List<IngestDocumentWrapper>> handler
197+
) {
185198
Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
186199
inferenceList = sortedResult.v1();
187200
Map<Integer, Integer> originalOrder = sortedResult.v2();
@@ -238,7 +251,7 @@ private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> or
238251
return sortedResults;
239252
}
240253

241-
private List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
254+
protected List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
242255
List<String> inferenceTexts = new ArrayList<>();
243256
for (DataForInference dataForInference : dataForInferences) {
244257
if (dataForInference.getIngestDocumentWrapper().getException() != null
@@ -250,7 +263,7 @@ private List<String> constructInferenceTexts(List<DataForInference> dataForInfer
250263
return inferenceTexts;
251264
}
252265

253-
private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
266+
protected List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
254267
List<DataForInference> dataForInferences = new ArrayList<>();
255268
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
256269
Map<String, Object> processMap = null;
@@ -272,7 +285,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
272285

273286
@Getter
274287
@AllArgsConstructor
275-
private static class DataForInference {
288+
protected static class DataForInference {
276289
private final IngestDocumentWrapper ingestDocumentWrapper;
277290
private final Map<String, Object> processMap;
278291
private final List<String> inferenceList;
@@ -415,6 +428,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
415428
nlpResult.forEach(ingestDocument::setFieldValue);
416429
}
417430

431+
/**
432+
* This method creates a MultiGetRequest from a list of ingest documents to be fetched for comparison
433+
* @param ingestDocumentWrappers, list of ingest documents
434+
* */
435+
protected MultiGetRequest buildMultiGetRequest(List<IngestDocumentWrapper> ingestDocumentWrappers) {
436+
MultiGetRequest multiGetRequest = new MultiGetRequest();
437+
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
438+
Object index = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(INDEX_FIELD);
439+
Object id = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(ID_FIELD);
440+
if (Objects.nonNull(index) && Objects.nonNull(id)) {
441+
multiGetRequest.add(index.toString(), id.toString());
442+
}
443+
}
444+
return multiGetRequest;
445+
}
446+
447+
/**
448+
* This method creates a map of documents from MultiGetItemResponse where the key is document ID and value is corresponding document
449+
* @param multiGetItemResponses, array of responses from Multi Get Request
450+
* */
451+
protected Map<String, Map<String, Object>> createDocumentMap(MultiGetItemResponse[] multiGetItemResponses) {
452+
Map<String, Map<String, Object>> existingDocuments = new HashMap<>();
453+
for (MultiGetItemResponse item : multiGetItemResponses) {
454+
String id = item.getId();
455+
Map<String, Object> existingDocument = item.getResponse().getSourceAsMap();
456+
existingDocuments.put(id, existingDocument);
457+
}
458+
return existingDocuments;
459+
}
460+
418461
@SuppressWarnings({ "unchecked" })
419462
@VisibleForTesting
420463
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
@@ -582,7 +625,7 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
582625
List<Map<String, Object>> keyToResult = new ArrayList<>();
583626
sourceValue.stream()
584627
.filter(Objects::nonNull) // explicit null check is required since sourceValue can contain null values in cases where
585-
// sourceValue has been filtered
628+
// sourceValue has been filtered
586629
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
587630
return keyToResult;
588631
}

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

+62
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
*/
55
package org.opensearch.neuralsearch.processor;
66

7+
import java.util.ArrayList;
8+
import java.util.Collections;
79
import java.util.List;
810
import java.util.Map;
911
import java.util.Objects;
@@ -13,10 +15,14 @@
1315

1416
import org.opensearch.action.get.GetAction;
1517
import org.opensearch.action.get.GetRequest;
18+
import org.opensearch.action.get.MultiGetAction;
19+
import org.opensearch.action.get.MultiGetItemResponse;
1620
import org.opensearch.cluster.service.ClusterService;
1721
import org.opensearch.core.action.ActionListener;
22+
import org.opensearch.core.common.util.CollectionUtils;
1823
import org.opensearch.env.Environment;
1924
import org.opensearch.ingest.IngestDocument;
25+
import org.opensearch.ingest.IngestDocumentWrapper;
2026
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
2127

2228
import lombok.extern.log4j.Log4j2;
@@ -106,4 +112,60 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
106112
ActionListener.wrap(handler::accept, onException)
107113
);
108114
}
115+
116+
@Override
117+
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
118+
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
119+
handler.accept(Collections.emptyList());
120+
return;
121+
}
122+
123+
List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
124+
List<String> inferenceList = constructInferenceTexts(dataForInferences);
125+
if (inferenceList.isEmpty()) {
126+
handler.accept(ingestDocumentWrappers);
127+
return;
128+
}
129+
if (skipExisting == false) {
130+
doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
131+
return;
132+
}
133+
openSearchClient.execute(MultiGetAction.INSTANCE, buildMultiGetRequest(ingestDocumentWrappers), ActionListener.wrap(response -> {
134+
MultiGetItemResponse[] multiGetItemResponses = response.getResponses();
135+
if (multiGetItemResponses == null || multiGetItemResponses.length == 0) {
136+
doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
137+
return;
138+
}
139+
Map<String, Map<String, Object>> existingDocuments = createDocumentMap(multiGetItemResponses);
140+
List<DataForInference> filteredDataForInference = new ArrayList<>();
141+
for (DataForInference dataForInference : dataForInferences) {
142+
IngestDocumentWrapper ingestDocumentWrapper = dataForInference.getIngestDocumentWrapper();
143+
Map<String, Object> processMap = dataForInference.getProcessMap();
144+
Map<String, Object> document = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata();
145+
Object id = document.get(ID_FIELD);
146+
// insert non-filtered dataForInference if existing document does not exist
147+
if (Objects.isNull(id) || existingDocuments.containsKey(id.toString()) == false) {
148+
filteredDataForInference.add(dataForInference);
149+
continue;
150+
}
151+
// filter dataForInference when existing document exists
152+
String docId = id.toString();
153+
Map<String, Object> existingDocument = existingDocuments.get(docId);
154+
Map<String, Object> filteredProcessMap = textEmbeddingInferenceFilter.filter(existingDocument, document, processMap);
155+
List<String> filteredInferenceList = createInferenceList(filteredProcessMap);
156+
filteredDataForInference.add(new DataForInference(ingestDocumentWrapper, filteredProcessMap, filteredInferenceList));
157+
}
158+
List<String> filteredInferenceList = constructInferenceTexts(filteredDataForInference);
159+
if (filteredInferenceList.isEmpty()) {
160+
handler.accept(ingestDocumentWrappers);
161+
} else {
162+
doSubBatchExecute(ingestDocumentWrappers, filteredInferenceList, filteredDataForInference, handler);
163+
}
164+
}, e -> {
165+
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
166+
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
167+
}
168+
handler.accept(ingestDocumentWrappers);
169+
}));
170+
}
109171
}

src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java

+22-38
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.List;
1414
import java.util.ListIterator;
1515
import java.util.Map;
16+
import java.util.Objects;
1617
import java.util.Optional;
1718

1819
/**
@@ -71,7 +72,7 @@ public abstract Object filterInferenceValue(
7172
* @return The processed value or null if embeddings are reused
7273
*/
7374

74-
public abstract Object copyEmbeddingForSingleValue(
75+
public abstract Object copyEmbedding(
7576
String embeddingKey,
7677
Object processValue,
7778
Object existingValue,
@@ -80,27 +81,6 @@ public abstract Object copyEmbeddingForSingleValue(
8081
int index
8182
);
8283

83-
/**
84-
* Abstract method to filter and compare lists of values.
85-
* If all elements in the list are identical between the new and existing metadata maps, embeddings are copied,
86-
* and an empty list is returned to indicate no further processing is required.
87-
*
88-
* @param embeddingKey The dot-notation path for the embedding field
89-
* @param processList The list of values to be checked for potential embedding reuse
90-
* @param existingList The list of existing values for comparison
91-
* @param embeddingList The list of existing embeddings
92-
* @param sourceAndMetadataMap The metadata map of the new document.
93-
* @return A processed list or an empty list if embeddings are reused.
94-
*/
95-
96-
public abstract List<Object> copyEmbeddingForMultipleValues(
97-
String embeddingKey,
98-
List<Object> processList,
99-
List<Object> existingList,
100-
List<Object> embeddingList,
101-
Map<String, Object> sourceAndMetadataMap
102-
);
103-
10484
/**
10585
* This method navigates through the nested structure, checking each key-value pair recursively. It supports:
10686
* Map values: Processed recursively using this method.
@@ -155,7 +135,7 @@ private Map<String, Object> filter(
155135
);
156136
filteredProcessMap.put(key, filteredInnerMap.isEmpty() ? null : filteredInnerMap);
157137
} else if (value instanceof List) {
158-
List<Object> processedList = filterListValue(
138+
Object processedList = filterListValue(
159139
currentPath,
160140
ProcessorUtils.unsafeCastToObjectList(value),
161141
sourceAndMetadataMap,
@@ -194,13 +174,15 @@ protected List<Object> filterListValue(
194174
List<Object> existingListValue = ProcessorUtils.unsafeCastToObjectList(existingListOptional.get());
195175
if (existingListValue.getFirst() instanceof List) {
196176
// in case of nested list, compare and copy by list comparison
197-
return copyEmbeddingForMultipleValues(
177+
Object processedList = copyEmbedding(
198178
embeddingKey,
199179
processList,
200-
ProcessorUtils.unsafeCastToObjectList(existingListValue.getFirst()),
201-
ProcessorUtils.unsafeCastToObjectList(embeddingListOptional.get()),
202-
sourceAndMetadataMap
180+
existingListValue.getFirst(),
181+
embeddingListOptional.get(),
182+
sourceAndMetadataMap,
183+
-1
203184
);
185+
return Objects.nonNull(processedList) ? ProcessorUtils.unsafeCastToObjectList(processedList) : null;
204186
} else {
205187
// in case of List of Maps, compare each map entry in list
206188
return filterMapValuesInList(
@@ -231,20 +213,22 @@ public List<Object> filterMapValuesInList(
231213
Map<String, Object> sourceAndMetadataMap
232214
) {
233215
List<Object> filteredList = new ArrayList<>();
234-
ListIterator<Object> processListIterator = processList.listIterator();
235216
ListIterator<Object> existingListIterator = existingList.listIterator();
236217
ListIterator<Object> embeddingListIterator = embeddingList.listIterator();
237218
int index = 0;
238-
while (processListIterator.hasNext() && existingListIterator.hasNext() && embeddingListIterator.hasNext()) {
239-
Object processedItem = copyEmbeddingForSingleValue(
240-
embeddingKey,
241-
processListIterator.next(),
242-
existingListIterator.next(),
243-
embeddingListIterator.next(),
244-
sourceAndMetadataMap,
245-
index++
246-
);
247-
filteredList.add(processedItem);
219+
for (Object processValue : processList) {
220+
if (Objects.nonNull(processValue) && existingListIterator.hasNext() && embeddingListIterator.hasNext()) {
221+
Object processedItem = copyEmbedding(
222+
embeddingKey,
223+
processValue,
224+
existingListIterator.next(),
225+
embeddingListIterator.next(),
226+
sourceAndMetadataMap,
227+
index
228+
);
229+
filteredList.add(processedItem);
230+
}
231+
index++;
248232
}
249233
return filteredList;
250234
}

src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java

+3-27
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import lombok.extern.log4j.Log4j2;
88
import org.opensearch.neuralsearch.processor.util.ProcessorUtils;
99

10-
import java.util.Collections;
11-
import java.util.List;
1210
import java.util.Map;
1311
import java.util.Objects;
1412
import java.util.Optional;
@@ -47,7 +45,7 @@ public Object filterInferenceValue(
4745
Optional<Object> existingValueOptional = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, textPath);
4846
Optional<Object> embeddingValueOptional = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, embeddingKey);
4947
if (existingValueOptional.isPresent() && embeddingValueOptional.isPresent()) {
50-
return copyEmbeddingForSingleValue(
48+
return copyEmbedding(
5149
embeddingKey,
5250
processValue,
5351
existingValueOptional.get(),
@@ -60,14 +58,14 @@ public Object filterInferenceValue(
6058
}
6159

6260
/**
63-
* Copy a single value by checking if the text is identical in both the existing and new document.
61+
* Copy a single value by checking if the given text is identical in both the existing and new document.
6462
* If the text matches, the corresponding embedding is copied, and null is returned, indicating no further
6563
* processing is required.
6664
*
6765
* @return null if embeddings are reused; the processValue otherwise.
6866
*/
6967
@Override
70-
public Object copyEmbeddingForSingleValue(
68+
public Object copyEmbedding(
7169
String embeddingKey,
7270
Object processValue,
7371
Object existingValue,
@@ -83,26 +81,4 @@ public Object copyEmbeddingForSingleValue(
8381
// processValue and existingValue are different, return processValue to be included in process map
8482
return processValue;
8583
}
86-
87-
/**
88-
* Copy values in list by checking if all texts in list are identical in both the existing and new documents.
89-
* If lists are equal, the corresponding embeddings are copied
90-
* @return empty list if embeddings are reused; processList otherwise.
91-
*/
92-
@Override
93-
public List<Object> copyEmbeddingForMultipleValues(
94-
String embeddingKey,
95-
List<Object> processList,
96-
List<Object> existingList,
97-
List<Object> embeddingList,
98-
Map<String, Object> sourceAndMetadataMap
99-
) {
100-
if (Objects.equals(processList, existingList)) {
101-
ProcessorUtils.setValueToSource(sourceAndMetadataMap, embeddingKey, embeddingList);
102-
// if successfully copied, return empty list to be filtered out from process map
103-
return Collections.emptyList();
104-
}
105-
// source list and existing list are different, return processList to be included in process map
106-
return processList;
107-
}
10884
}

0 commit comments

Comments
 (0)