
Commit 8902065

implement single document update scenario in text embedding processor
Signed-off-by: Will Hwang <sang7239@gmail.com>
1 parent 628cb64 commit 8902065

24 files changed: +2344 −136 lines

CHANGELOG.md

+1
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
 ### Features
+- Add Optimized Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
 ### Enhancements
 - Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
 - Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

+6 −1
@@ -127,7 +127,12 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
         clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
         return Map.of(
             TextEmbeddingProcessor.TYPE,
-            new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
+            new TextEmbeddingProcessorFactory(
+                parameters.client,
+                clientAccessor,
+                parameters.env,
+                parameters.ingestService.getClusterService()
+            ),
             SparseEncodingProcessor.TYPE,
             new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
             TextImageEmbeddingProcessor.TYPE,

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java

+30 −3
@@ -27,6 +27,7 @@
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
 import org.opensearch.common.collect.Tuple;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.common.util.CollectionUtils;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.env.Environment;
@@ -118,7 +119,7 @@ private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
 
     public abstract void doExecute(
         IngestDocument ingestDocument,
-        Map<String, Object> ProcessMap,
+        Map<String, Object> processMap,
         List<String> inferenceList,
         BiConsumer<IngestDocument, Exception> handler
     );
@@ -278,7 +279,7 @@ private static class DataForInference {
     }
 
     @SuppressWarnings({ "unchecked" })
-    private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
+    protected List<String> createInferenceList(Map<String, Object> knnKeyMap) {
         List<String> texts = new ArrayList<>();
         knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> {
             Object sourceValue = knnMapEntry.getValue();
@@ -579,11 +580,37 @@ private Map<String, Object> getSourceMapBySourceAndMetadataMap(String processorK
 
     private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
         List<Map<String, Object>> keyToResult = new ArrayList<>();
-        IntStream.range(0, sourceValue.size())
+        sourceValue.stream()
+            .filter(Objects::nonNull) // explicit null check is required since sourceValue can contain null values
+                                      // in cases where sourceValue has been filtered
             .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
         return keyToResult;
     }
 
+    /**
+     * This method invokes the inference call through mlCommonsClientAccessor and populates the retrieved embeddings into ingestDocument
+     *
+     * @param ingestDocument ingest document to populate embeddings into
+     * @param processMap map indicating the paths in the ingest document to populate embeddings at
+     * @param inferenceList list of texts to run model inference on
+     * @param handler callback invoked with the updated document on success, or with the exception on failure
+     *
+     */
+    protected void makeInferenceCall(
+        IngestDocument ingestDocument,
+        Map<String, Object> processMap,
+        List<String> inferenceList,
+        BiConsumer<IngestDocument, Exception> handler
+    ) {
+        mlCommonsClientAccessor.inferenceSentences(
+            TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
+            ActionListener.wrap(vectors -> {
+                setVectorFieldsToDocument(ingestDocument, processMap, vectors);
+                handler.accept(ingestDocument, null);
+            }, e -> { handler.accept(null, e); })
+        );
+    }
+
     @Override
     public String getType() {
         return type;
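The null filter added to buildNLPResultForListType is what keeps list-valued fields aligned once filtering can leave unchanged entries as null while the model returns vectors only for the texts that were actually sent. A minimal, self-contained sketch of that alignment, assuming a plain AtomicInteger in place of the processor's IndexWrapper and a toy "knn" result map rather than the plugin's types:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;

public class ListAlignmentSketch {
    public static void main(String[] args) {
        // Source list after filtering: the second entry was unchanged and left as null,
        // so no embedding was requested for it (List.of would reject nulls, hence ArrayList).
        List<String> sourceValue = new ArrayList<>();
        sourceValue.add("a brand new sentence");
        sourceValue.add(null);
        sourceValue.add("another changed sentence");

        // The model returned vectors only for the two texts that were actually sent.
        List<List<Float>> results = List.of(List.of(0.1f, 0.2f), List.of(0.3f, 0.4f));

        // Walk the results with a shared counter, skipping null source entries,
        // mirroring the .filter(Objects::nonNull) added above.
        AtomicInteger index = new AtomicInteger(0);
        List<Map<String, Object>> keyToResult = new ArrayList<>();
        sourceValue.stream()
            .filter(Objects::nonNull)
            .forEachOrdered(x -> keyToResult.add(Map.<String, Object>of("knn", results.get(index.getAndIncrement()))));

        // Two entries, one per non-null source value, in the original order.
        System.out.println(keyToResult);
    }
}
```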

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

+50 −7
@@ -6,16 +6,22 @@
 
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
+import java.util.stream.Collectors;
 
+import org.opensearch.action.get.GetAction;
+import org.opensearch.action.get.GetRequest;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.env.Environment;
 import org.opensearch.ingest.IngestDocument;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 
 import lombok.extern.log4j.Log4j2;
+import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
+import org.opensearch.transport.client.OpenSearchClient;
 
 /**
  * This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
@@ -26,34 +32,71 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
 
     public static final String TYPE = "text_embedding";
     public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
+    public static final String SKIP_EXISTING = "skip_existing";
+    public static final boolean DEFAULT_SKIP_EXISTING = false;
+    private static final String INDEX_FIELD = "_index";
+    private static final String ID_FIELD = "_id";
+    private final OpenSearchClient openSearchClient;
+    private final boolean skipExisting;
+    private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;
 
     public TextEmbeddingProcessor(
         String tag,
         String description,
         int batchSize,
         String modelId,
         Map<String, Object> fieldMap,
+        boolean skipExisting,
+        TextEmbeddingInferenceFilter textEmbeddingInferenceFilter,
+        OpenSearchClient openSearchClient,
         MLCommonsClientAccessor clientAccessor,
         Environment environment,
         ClusterService clusterService
     ) {
         super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
+        this.skipExisting = skipExisting;
+        this.textEmbeddingInferenceFilter = textEmbeddingInferenceFilter;
+        this.openSearchClient = openSearchClient;
     }
 
     @Override
     public void doExecute(
         IngestDocument ingestDocument,
-        Map<String, Object> ProcessMap,
+        Map<String, Object> processMap,
         List<String> inferenceList,
         BiConsumer<IngestDocument, Exception> handler
     ) {
-        mlCommonsClientAccessor.inferenceSentences(
-            TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
-            ActionListener.wrap(vectors -> {
-                setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
+        // skip_existing flag is turned off: call model inference without filtering
+        if (skipExisting == false) {
+            makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
+            return;
+        }
+        // if skipExisting is turned on, eligible inference texts are compared and filtered after existing embeddings are copied
+        String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString();
+        String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString();
+        openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> {
+            final Map<String, Object> existingDocument = response.getSourceAsMap();
+            if (existingDocument == null || existingDocument.isEmpty()) {
+                makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
+                return;
+            }
+            // filter the given processMap by comparing the existing document with the ingest document
+            Map<String, Object> filteredProcessMap = textEmbeddingInferenceFilter.filter(
+                existingDocument,
+                ingestDocument.getSourceAndMetadata(),
+                processMap
+            );
+            // create the inference list based on the filtered processMap
+            List<String> filteredInferenceList = createInferenceList(filteredProcessMap).stream()
+                .filter(Objects::nonNull)
+                .collect(Collectors.toList());
+            if (filteredInferenceList.isEmpty()) {
                 handler.accept(ingestDocument, null);
-            }, e -> { handler.accept(null, e); })
-        );
+            } else {
+                makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler);
+            }
+
+        }, e -> { handler.accept(null, e); }));
     }
 
     @Override
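The doExecute change above only wires the flow: look up the previously indexed copy with a GetRequest, hand both source maps to TextEmbeddingInferenceFilter, and run inference on whatever survives. The filter class itself is added elsewhere in this commit and is not shown here, so the following is only a rough, hypothetical illustration of the idea — compare each mapped source field against the stored copy and keep only the changed ones — and it ignores nested field maps and the copying of already-stored embeddings that the real filter is responsible for:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class SkipExistingSketch {

    // Returns only the mapped source fields whose text differs from the previously indexed
    // copy; unchanged fields can keep their stored vectors and need no new inference.
    static Map<String, Object> fieldsNeedingInference(
        Map<String, Object> existingSource,
        Map<String, Object> newSource,
        Map<String, Object> fieldMap
    ) {
        Map<String, Object> toEmbed = new HashMap<>();
        for (String sourceField : fieldMap.keySet()) {
            Object oldValue = existingSource.get(sourceField);
            Object newValue = newSource.get(sourceField);
            if (newValue != null && !Objects.equals(oldValue, newValue)) {
                toEmbed.put(sourceField, newValue);
            }
        }
        return toEmbed;
    }

    public static void main(String[] args) {
        // Hypothetical field_map: source field -> embedding field.
        Map<String, Object> fieldMap = Map.of("title", "title_embedding", "body", "body_embedding");
        Map<String, Object> existing = Map.of("title", "old title", "body", "same body");
        Map<String, Object> updated = Map.of("title", "new title", "body", "same body");
        // Only "title" changed, so only it would be sent to the model.
        System.out.println(fieldsNeedingInference(existing, updated, fieldMap));
    }
}
```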

src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java

+25 −2
@@ -4,8 +4,11 @@
  */
 package org.opensearch.neuralsearch.processor.factory;
 
+import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty;
 import static org.opensearch.ingest.ConfigurationUtils.readMap;
 import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
+import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.SKIP_EXISTING;
+import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.DEFAULT_SKIP_EXISTING;
 import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;
 import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD;
 import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD;
@@ -17,24 +20,30 @@
 import org.opensearch.ingest.AbstractBatchingProcessor;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
+import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
+import org.opensearch.transport.client.OpenSearchClient;
 
 /**
  * Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
  */
 public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcessor.Factory {
 
+    private final OpenSearchClient openSearchClient;
+
     private final MLCommonsClientAccessor clientAccessor;
 
     private final Environment environment;
 
     private final ClusterService clusterService;
 
     public TextEmbeddingProcessorFactory(
+        final OpenSearchClient openSearchClient,
         final MLCommonsClientAccessor clientAccessor,
         final Environment environment,
         final ClusterService clusterService
     ) {
         super(TYPE);
+        this.openSearchClient = openSearchClient;
         this.clientAccessor = clientAccessor;
         this.environment = environment;
         this.clusterService = clusterService;
@@ -43,7 +52,21 @@ public TextEmbeddingProcessorFactory(
     @Override
     protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
         String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
-        Map<String, Object> filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
-        return new TextEmbeddingProcessor(tag, description, batchSize, modelId, filedMap, clientAccessor, environment, clusterService);
+        Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
+        boolean skipExisting = readBooleanProperty(TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING);
+        TextEmbeddingInferenceFilter textEmbeddingInferenceFilter = new TextEmbeddingInferenceFilter(fieldMap);
+        return new TextEmbeddingProcessor(
+            tag,
+            description,
+            batchSize,
+            modelId,
+            fieldMap,
+            skipExisting,
+            textEmbeddingInferenceFilter,
+            openSearchClient,
+            clientAccessor,
+            environment,
+            clusterService
+        );
     }
 }
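From a pipeline author's point of view, the factory change adds one optional key alongside the existing model_id and field_map settings: skip_existing, read via readBooleanProperty and defaulting to DEFAULT_SKIP_EXISTING (false). A sketch of the config map shape the factory consumes; the model id and field names below are placeholders, not values from this commit:

```java
import java.util.HashMap;
import java.util.Map;

public class ProcessorConfigSketch {
    public static void main(String[] args) {
        // A text_embedding processor definition, expressed as the config map the factory reads.
        // "skip_existing" is optional; when absent the processor behaves as before (false).
        Map<String, Object> config = new HashMap<>();
        config.put("model_id", "my-model-id");                                  // placeholder model id
        config.put("field_map", Map.of("passage_text", "passage_embedding"));   // placeholder field names
        config.put("skip_existing", true);                                      // re-embed only changed fields on updates
        System.out.println(config);
    }
}
```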
