Commit 54ac672

fix map type validation issue in processors (#687)
* fix map type validation issue in processors
* fix test failures on main branch
* Fix potential NPE issue in chunking processor; add change log
* Fix failure tests
* Address comments and add one more UT to cover uncovered line
* Address comments
* Add more UTs
* fix failure ITs
* Add public method with default depth parameter value
* rebase latest code
* address comments
* address comment

Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent 8705980 commit 54ac672

18 files changed (+648 -240 lines)

CHANGELOG.md (+1)

@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Optimize max score calculation in the Query Phase of the Hybrid Search ([765](https://github.com/opensearch-project/neural-search/pull/765))
 ### Bug Fixes
 - Total hit count fix in Hybrid Query ([756](https://github.com/opensearch-project/neural-search/pull/756))
+- Fix map type validation issue in multiple pipeline processors ([#661](https://github.com/opensearch-project/neural-search/pull/661))
 ### Infrastructure
 - Disable memory circuit breaker for integ tests ([#770](https://github.com/opensearch-project/neural-search/pull/770))
 ### Documentation

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java (+2 -2)

@@ -113,9 +113,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters
         clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
         return Map.of(
             TextEmbeddingProcessor.TYPE,
-            new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
+            new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
             SparseEncodingProcessor.TYPE,
-            new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
+            new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
             TextImageEmbeddingProcessor.TYPE,
             new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
             TextChunkingProcessor.TYPE,
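
The TextEmbeddingProcessorFactory and SparseEncodingProcessorFactory classes are also updated in this commit, but their diffs are not shown in this excerpt; note that TextImageEmbeddingProcessorFactory already received a ClusterService, as the unchanged line above shows, and this change brings the other factories in line with it. Below is a minimal sketch of how the extra ClusterService argument is presumably stored and forwarded by one of these factories. The field names, the config parsing, and the elided neural-search imports are assumptions; only the constructor chain follows the diffs shown on this page (the real factory presumably reads its properties via the ingest ConfigurationUtils helpers).

// Hypothetical sketch -- the real TextEmbeddingProcessorFactory changed by this commit is not shown above.
// Imports of neural-search classes (MLCommonsClientAccessor, TextEmbeddingProcessor) are elided here.
import java.util.Map;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.ingest.Processor;

public class TextEmbeddingProcessorFactory implements Processor.Factory {

    private final MLCommonsClientAccessor clientAccessor;
    private final Environment environment;
    private final ClusterService clusterService; // new dependency supplied by NeuralSearch.getProcessors

    public TextEmbeddingProcessorFactory(
        final MLCommonsClientAccessor clientAccessor,
        final Environment environment,
        final ClusterService clusterService
    ) {
        this.clientAccessor = clientAccessor;
        this.environment = environment;
        this.clusterService = clusterService;
    }

    @Override
    @SuppressWarnings("unchecked")
    public TextEmbeddingProcessor create(
        final Map<String, Processor.Factory> registry,
        final String tag,
        final String description,
        final Map<String, Object> config
    ) {
        // Property names are illustrative only.
        final String modelId = (String) config.remove("model_id");
        final Map<String, Object> fieldMap = (Map<String, Object>) config.remove("field_map");
        // The constructor arguments match the updated TextEmbeddingProcessor signature shown further down.
        return new TextEmbeddingProcessor(tag, description, modelId, fieldMap, clientAccessor, environment, clusterService);
    }
}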

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java (+22 -57)

@@ -15,7 +15,6 @@
 import java.util.Objects;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
-import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;

@@ -24,8 +23,9 @@
 import org.apache.commons.lang3.StringUtils;
 import org.opensearch.common.collect.Tuple;
 import org.opensearch.core.common.util.CollectionUtils;
+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.env.Environment;
-import org.opensearch.index.mapper.MapperService;
+import org.opensearch.index.mapper.IndexFieldMapper;
 import org.opensearch.ingest.AbstractProcessor;
 import org.opensearch.ingest.IngestDocument;
 import org.opensearch.ingest.IngestDocumentWrapper;
@@ -35,6 +35,7 @@
 import com.google.common.collect.ImmutableMap;

 import lombok.extern.log4j.Log4j2;
+import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

 /**
  * The abstract class for text processing use cases. Users provide a field name map and a model id.
@@ -60,6 +61,7 @@ public abstract class InferenceProcessor extends AbstractProcessor {
     protected final MLCommonsClientAccessor mlCommonsClientAccessor;

     private final Environment environment;
+    private final ClusterService clusterService;

     public InferenceProcessor(
         String tag,
@@ -69,18 +71,19 @@ public InferenceProcessor(
         String modelId,
         Map<String, Object> fieldMap,
         MLCommonsClientAccessor clientAccessor,
-        Environment environment
+        Environment environment,
+        ClusterService clusterService
     ) {
         super(tag, description);
         this.type = type;
         if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it");
         validateEmbeddingConfiguration(fieldMap);
-
         this.listTypeNestedMapKey = listTypeNestedMapKey;
         this.modelId = modelId;
         this.fieldMap = fieldMap;
         this.mlCommonsClientAccessor = clientAccessor;
         this.environment = environment;
+        this.clusterService = clusterService;
     }

     private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
@@ -117,12 +120,12 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
     public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
         try {
             validateEmbeddingFieldsValue(ingestDocument);
-            Map<String, Object> ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
-            List<String> inferenceList = createInferenceList(ProcessMap);
+            Map<String, Object> processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocument);
+            List<String> inferenceList = createInferenceList(processMap);
             if (inferenceList.size() == 0) {
                 handler.accept(ingestDocument, null);
             } else {
-                doExecute(ingestDocument, ProcessMap, inferenceList, handler);
+                doExecute(ingestDocument, processMap, inferenceList, handler);
             }
         } catch (Exception e) {
             handler.accept(null, e);
@@ -225,7 +228,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
         List<String> inferenceList = null;
         try {
             validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
-            processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
+            processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
             inferenceList = createInferenceList(processMap);
         } catch (Exception e) {
             ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
@@ -273,7 +276,7 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List<String>
     }

     @VisibleForTesting
-    Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) {
+    Map<String, Object> buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestDocument) {
         Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
         Map<String, Object> mapWithProcessorKeys = new LinkedHashMap<>();
         for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) {
@@ -331,54 +334,16 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType(

     private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
         Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
-        for (Map.Entry<String, Object> embeddingFieldsEntry : fieldMap.entrySet()) {
-            Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
-            if (sourceValue != null) {
-                String sourceKey = embeddingFieldsEntry.getKey();
-                Class<?> sourceValueClass = sourceValue.getClass();
-                if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
-                    validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
-                } else if (!String.class.isAssignableFrom(sourceValueClass)) {
-                    throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it");
-                } else if (StringUtils.isBlank(sourceValue.toString())) {
-                    throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it");
-                }
-            }
-        }
-    }
-
-    @SuppressWarnings({ "rawtypes", "unchecked" })
-    private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
-        int maxDepth = maxDepthSupplier.get();
-        if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
-            throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
-        } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
-            validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier);
-        } else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
-            ((Map) sourceValue).values()
-                .stream()
-                .filter(Objects::nonNull)
-                .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
-        } else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
-            throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
-        } else if (StringUtils.isBlank(sourceValue.toString())) {
-            throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it");
-        }
-    }
-
-    @SuppressWarnings({ "rawtypes" })
-    private void validateListTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
-        for (Object value : (List) sourceValue) {
-            if (value instanceof Map) {
-                validateNestedTypeValue(sourceKey, value, () -> maxDepthSupplier.get() + 1);
-            } else if (value == null) {
-                throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
-            } else if (!(value instanceof String)) {
-                throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
-            } else if (StringUtils.isBlank(value.toString())) {
-                throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
-            }
-        }
+        String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
+        ProcessorDocumentUtils.validateMapTypeValue(
+            FIELD_MAP_FIELD,
+            sourceAndMetadataMap,
+            fieldMap,
+            indexName,
+            clusterService,
+            environment,
+            false
+        );
     }

     protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
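
The new ProcessorDocumentUtils class that validateEmbeddingFieldsValue now delegates to is added by this commit but is not among the files shown in this excerpt. The sketch below reconstructs what the shared helper plausibly looks like from its call sites (here and in TextChunkingProcessor further down), from the validation logic removed above, and from the commit message note about a public method with a default depth value. Parameter names, the exact recursion, and the error messages are assumptions, and list handling is elided.

// Hypothetical sketch of org.opensearch.neuralsearch.util.ProcessorDocumentUtils -- not the actual file from this commit.
package org.opensearch.neuralsearch.util;

import java.util.Locale;
import java.util.Map;
import java.util.Objects;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;

public final class ProcessorDocumentUtils {

    private ProcessorDocumentUtils() {}

    // Public entry point with a default starting depth of 1
    // (the commit message mentions "Add public method with default depth parameter value").
    public static void validateMapTypeValue(
        final String parameterName,
        final Map<String, Object> sourceAndMetadataMap,
        final Map<String, Object> fieldMap,
        final String indexName,
        final ClusterService clusterService,
        final Environment environment,
        final boolean allowEmpty
    ) {
        validateMapTypeValue(parameterName, sourceAndMetadataMap, fieldMap, 1, indexName, clusterService, environment, allowEmpty);
    }

    @SuppressWarnings("unchecked")
    private static void validateMapTypeValue(
        final String parameterName,
        final Map<String, Object> sourceAndMetadataMap,
        final Map<String, Object> fieldMap,
        final long depth,
        final String indexName,
        final ClusterService clusterService,
        final Environment environment,
        final boolean allowEmpty
    ) {
        if (Objects.isNull(sourceAndMetadataMap) || Objects.isNull(fieldMap)) {
            return;
        }
        // Resolve the depth limit from the target index settings when the index exists, falling back to
        // the node-level default; the removed per-processor code only consulted environment.settings(),
        // which is the kind of gap this refactoring appears to close.
        final IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
        final Settings settings = indexMetadata == null ? environment.settings() : indexMetadata.getSettings();
        final long maxDepth = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(settings);
        if (depth > maxDepth) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "map type field [%s] reached max depth limit, cannot process it", parameterName)
            );
        }
        for (final Map.Entry<String, Object> entry : fieldMap.entrySet()) {
            final String sourceKey = entry.getKey();
            final Object sourceValue = sourceAndMetadataMap.get(sourceKey);
            if (sourceValue == null) {
                continue;
            }
            if (sourceValue instanceof Map && entry.getValue() instanceof Map) {
                // Recurse into nested mappings, bumping the depth.
                validateMapTypeValue(
                    parameterName,
                    (Map<String, Object>) sourceValue,
                    (Map<String, Object>) entry.getValue(),
                    depth + 1,
                    indexName,
                    clusterService,
                    environment,
                    allowEmpty
                );
            } else if (sourceValue instanceof String) {
                if (!allowEmpty && StringUtils.isBlank((String) sourceValue)) {
                    throw new IllegalArgumentException(
                        String.format(Locale.ROOT, "field [%s] has empty string value, cannot process it", sourceKey)
                    );
                }
            }
            // List values and non-string leaves need analogous checks; they are elided in this sketch.
        }
    }
}

The last boolean argument differs between the two callers in this commit: InferenceProcessor passes false while TextChunkingProcessor passes true, which suggests it controls whether empty string values are tolerated; the allowEmpty name used here is an assumption.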

src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java (+4 -2)

@@ -9,6 +9,7 @@
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;

+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.env.Environment;
 import org.opensearch.ingest.IngestDocument;
@@ -33,9 +34,10 @@ public SparseEncodingProcessor(
         String modelId,
         Map<String, Object> fieldMap,
         MLCommonsClientAccessor clientAccessor,
-        Environment environment
+        Environment environment,
+        ClusterService clusterService
     ) {
-        super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment);
+        super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
     }

     @Override

src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java (+17 -55)

@@ -17,14 +17,14 @@
 import org.opensearch.env.Environment;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.index.analysis.AnalysisRegistry;
-import org.opensearch.index.mapper.MapperService;
 import org.opensearch.index.IndexSettings;
 import org.opensearch.ingest.AbstractProcessor;
 import org.opensearch.ingest.IngestDocument;
 import org.opensearch.neuralsearch.processor.chunker.Chunker;
 import org.opensearch.index.mapper.IndexFieldMapper;
 import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory;
 import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker;
+import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

 import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD;
 import static org.opensearch.neuralsearch.processor.chunker.Chunker.DEFAULT_MAX_CHUNK_LIMIT;
@@ -164,7 +164,16 @@ private int getMaxTokenCount(final Map<String, Object> sourceAndMetadataMap) {
     @Override
     public IngestDocument execute(final IngestDocument ingestDocument) {
         Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
-        validateFieldsValue(sourceAndMetadataMap);
+        String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
+        ProcessorDocumentUtils.validateMapTypeValue(
+            FIELD_MAP_FIELD,
+            sourceAndMetadataMap,
+            fieldMap,
+            indexName,
+            clusterService,
+            environment,
+            true
+        );
         // fixed token length algorithm needs runtime parameter max_token_count for tokenization
         Map<String, Object> runtimeParameters = new HashMap<>();
         int maxTokenCount = getMaxTokenCount(sourceAndMetadataMap);
@@ -176,59 +185,6 @@ public IngestDocument execute(final IngestDocument ingestDocument) {
         return ingestDocument;
     }

-    private void validateFieldsValue(final Map<String, Object> sourceAndMetadataMap) {
-        for (Map.Entry<String, Object> embeddingFieldsEntry : fieldMap.entrySet()) {
-            Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
-            if (Objects.nonNull(sourceValue)) {
-                String sourceKey = embeddingFieldsEntry.getKey();
-                if (sourceValue instanceof List || sourceValue instanceof Map) {
-                    validateNestedTypeValue(sourceKey, sourceValue, 1);
-                } else if (!(sourceValue instanceof String)) {
-                    throw new IllegalArgumentException(
-                        String.format(Locale.ROOT, "field [%s] is neither string nor nested type, cannot process it", sourceKey)
-                    );
-                }
-            }
-        }
-    }
-
-    @SuppressWarnings({ "rawtypes", "unchecked" })
-    private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final int maxDepth) {
-        if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
-            throw new IllegalArgumentException(
-                String.format(Locale.ROOT, "map type field [%s] reached max depth limit, cannot process it", sourceKey)
-            );
-        } else if (sourceValue instanceof List) {
-            validateListTypeValue(sourceKey, sourceValue, maxDepth);
-        } else if (sourceValue instanceof Map) {
-            ((Map) sourceValue).values()
-                .stream()
-                .filter(Objects::nonNull)
-                .forEach(x -> validateNestedTypeValue(sourceKey, x, maxDepth + 1));
-        } else if (!(sourceValue instanceof String)) {
-            throw new IllegalArgumentException(
-                String.format(Locale.ROOT, "map type field [%s] has non-string type, cannot process it", sourceKey)
-            );
-        }
-    }
-
-    @SuppressWarnings({ "rawtypes" })
-    private void validateListTypeValue(final String sourceKey, final Object sourceValue, final int maxDepth) {
-        for (Object value : (List) sourceValue) {
-            if (value instanceof Map) {
-                validateNestedTypeValue(sourceKey, value, maxDepth + 1);
-            } else if (value == null) {
-                throw new IllegalArgumentException(
-                    String.format(Locale.ROOT, "list type field [%s] has null, cannot process it", sourceKey)
-                );
-            } else if (!(value instanceof String)) {
-                throw new IllegalArgumentException(
-                    String.format(Locale.ROOT, "list type field [%s] has non-string value, cannot process it", sourceKey)
-                );
-            }
-        }
-    }
-
     @SuppressWarnings("unchecked")
     private int getChunkStringCountFromMap(Map<String, Object> sourceAndMetadataMap, final Map<String, Object> fieldMap) {
         int chunkStringCount = 0;
@@ -334,7 +290,13 @@ private List<String> chunkLeafType(final Object value, final Map<String, Object>
         // leaf type means null, String or List<String>
         // the result should be an empty list when the input is null
         List<String> result = new ArrayList<>();
+        if (value == null) {
+            return result;
+        }
         if (value instanceof String) {
+            if (StringUtils.isBlank(String.valueOf(value))) {
+                return result;
+            }
             result = chunkString(value.toString(), runTimeParameters);
         } else if (isListOfString(value)) {
             result = chunkList((List<String>) value, runTimeParameters);
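
The chunkLeafType change above adds explicit guards for null and blank leaf values before chunkString is invoked, which lines up with the commit message's mention of a potential NPE in the chunking processor. A stand-alone, hypothetical rendering of just those guards follows; the class and method names are invented for illustration, and chunkString is stubbed out.

// Hypothetical stand-alone rendering of the new guards in chunkLeafType; not code from this commit.
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.lang3.StringUtils;

final class ChunkLeafGuardSketch {

    // Stand-in for TextChunkingProcessor#chunkLeafType after this commit.
    static List<String> chunkLeafTypeSketch(final Object value) {
        final List<String> result = new ArrayList<>();
        if (value == null) {
            return result; // new guard: a null leaf yields no chunks instead of being processed further
        }
        if (value instanceof String) {
            if (StringUtils.isBlank(String.valueOf(value))) {
                return result; // new guard: blank strings are skipped rather than handed to chunkString
            }
            result.add(value.toString()); // stands in for chunkString(value.toString(), runtimeParameters)
        }
        return result;
    }
}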

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java (+4 -2)

@@ -9,6 +9,7 @@
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;

+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.env.Environment;
 import org.opensearch.ingest.IngestDocument;
@@ -32,9 +33,10 @@ public TextEmbeddingProcessor(
         String modelId,
         Map<String, Object> fieldMap,
         MLCommonsClientAccessor clientAccessor,
-        Environment environment
+        Environment environment,
+        ClusterService clusterService
     ) {
-        super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment);
+        super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
     }

     @Override
