Skip to content

Commit 9bf5cce

Browse files
committed
address casting issues
Signed-off-by: will-hwang <sang7239@gmail.com>
1 parent 90851b6 commit 9bf5cce

File tree

6 files changed

+67
-47
lines changed

6 files changed

+67
-47
lines changed

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
119119

120120
public abstract void doExecute(
121121
IngestDocument ingestDocument,
122-
Map<String, Object> ProcessMap,
122+
Map<String, Object> processMap,
123123
List<String> inferenceList,
124124
BiConsumer<IngestDocument, Exception> handler
125125
);
@@ -167,7 +167,7 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
167167
* @param handler a callback handler to handle inference results which is a list of objects.
168168
* @param onException an exception callback to handle exception.
169169
*/
170-
protected abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);
170+
abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);
171171

172172
@Override
173173
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
@@ -591,21 +591,21 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
591591
* This method invokes inference call through mlCommonsClientAccessor and populates retrieved embeddings to ingestDocument
592592
*
593593
* @param ingestDocument ingestDocument to populate embeddings to
594-
* @param ProcessMap map indicating the path in ingestDocument to populate embeddings
594+
* @param processMap map indicating the path in ingestDocument to populate embeddings
595595
* @param inferenceList list of texts to be model inference
596596
* @param handler SourceAndMetadataMap of ingestDocument Document
597597
*
598598
*/
599599
protected void makeInferenceCall(
600600
IngestDocument ingestDocument,
601-
Map<String, Object> ProcessMap,
601+
Map<String, Object> processMap,
602602
List<String> inferenceList,
603603
BiConsumer<IngestDocument, Exception> handler
604604
) {
605605
mlCommonsClientAccessor.inferenceSentences(
606606
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
607607
ActionListener.wrap(vectors -> {
608-
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
608+
setVectorFieldsToDocument(ingestDocument, processMap, vectors);
609609
handler.accept(ingestDocument, null);
610610
}, e -> { handler.accept(null, e); })
611611
);

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

+12-17
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
3333
public static final String TYPE = "text_embedding";
3434
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
3535
public static final String SKIP_EXISTING = "skip_existing";
36-
public static final boolean DEFAULT_SKIP_EXISTING = Boolean.FALSE;
36+
public static final boolean DEFAULT_SKIP_EXISTING = false;
3737
private static final String INDEX_FIELD = "_index";
3838
private static final String ID_FIELD = "_id";
3939
private final OpenSearchClient openSearchClient;
@@ -62,44 +62,39 @@ public TextEmbeddingProcessor(
6262
@Override
6363
public void doExecute(
6464
IngestDocument ingestDocument,
65-
Map<String, Object> ProcessMap,
65+
Map<String, Object> processMap,
6666
List<String> inferenceList,
6767
BiConsumer<IngestDocument, Exception> handler
6868
) {
69-
if (skipExisting) { // if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings
70-
// have been copied
69+
if (skipExisting) {
70+
// if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings are copied
7171
String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString();
7272
String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString();
7373
openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> {
7474
final Map<String, Object> existingDocument = response.getSourceAsMap();
7575
if (existingDocument == null || existingDocument.isEmpty()) {
76-
makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler);
76+
makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
7777
} else {
7878
// filter given ProcessMap by comparing existing document with ingestDocument
7979
Map<String, Object> filteredProcessMap = textEmbeddingInferenceFilter.filter(
8080
existingDocument,
8181
ingestDocument.getSourceAndMetadata(),
82-
ProcessMap
82+
processMap
8383
);
8484
// create inference list based on filtered ProcessMap
8585
List<String> filteredInferenceList = createInferenceList(filteredProcessMap).stream()
8686
.filter(Objects::nonNull)
8787
.collect(Collectors.toList());
88-
if (!filteredInferenceList.isEmpty()) {
89-
makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler);
90-
} else {
88+
if (filteredInferenceList.isEmpty()) {
9189
handler.accept(ingestDocument, null);
90+
} else {
91+
makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler);
9292
}
9393
}
9494
}, e -> { handler.accept(null, e); }));
95-
} else { // skip existing flag is turned off. Call model inference without filtering
96-
mlCommonsClientAccessor.inferenceSentences(
97-
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
98-
ActionListener.wrap(vectors -> {
99-
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
100-
handler.accept(ingestDocument, null);
101-
}, e -> { handler.accept(null, e); })
102-
);
95+
} else {
96+
// skip existing flag is turned off. Call model inference without filtering
97+
makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
10398
}
10499
}
105100

src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java

+5-10
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,7 @@ public abstract List<Object> filterInferenceValuesInList(
9292
* @param existingSourceAndMetadataMap The metadata map of the existing document.
9393
* @param sourceAndMetadataMap The metadata map of the new document.
9494
* @param processMap The current map being processed.
95-
*
9695
* @return A filtered map containing only elements that require new embeddings.
97-
*
9896
*/
9997
public Map<String, Object> filter(
10098
Map<String, Object> existingSourceAndMetadataMap,
@@ -129,11 +127,8 @@ private Map<String, Object> filter(
129127
}
130128
Map<String, Object> filteredProcessMap = new HashMap<>();
131129
Map<String, Object> castedProcessMap = ProcessorUtils.castToMap(processMap);
132-
for (Map.Entry<?, ?> entry : castedProcessMap.entrySet()) {
133-
if ((entry.getKey() instanceof String) == false) {
134-
throw new IllegalArgumentException("key for processMap must be a string");
135-
}
136-
String key = (String) entry.getKey();
130+
for (Map.Entry<String, Object> entry : castedProcessMap.entrySet()) {
131+
String key = entry.getKey();
137132
Object value = entry.getValue();
138133
String currentPath = traversedPath.isEmpty() ? key : traversedPath + "." + key;
139134
if (value instanceof Map<?, ?>) {
@@ -142,7 +137,7 @@ private Map<String, Object> filter(
142137
} else if (value instanceof List) {
143138
List<Object> processedList = filterListValue(
144139
currentPath,
145-
(List<Object>) value,
140+
ProcessorUtils.castToObjectList(value),
146141
sourceAndMetadataMap,
147142
existingSourceAndMetadataMap
148143
);
@@ -192,8 +187,8 @@ protected List<Object> filterListValue(
192187
// return empty list if processList and existingList are equal and embeddings are copied, return empty list otherwise
193188
return filterInferenceValuesInList(
194189
processList,
195-
(List<Object>) existingList.get(),
196-
(List<Object>) embeddingList.get(),
190+
ProcessorUtils.castToObjectList(existingList.get()),
191+
ProcessorUtils.castToObjectList(embeddingList.get()),
197192
sourceAndMetadataMap,
198193
existingSourceAndMetadataMap,
199194
embeddingKey

src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java

+10-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.Collections;
1111
import java.util.List;
1212
import java.util.Map;
13+
import java.util.Objects;
1314
import java.util.Optional;
1415

1516
/**
@@ -43,23 +44,24 @@ public Object filterInferenceValue(
4344
int index
4445
) {
4546
String textPath = reversedFieldMap.get(embeddingPath);
46-
if (textPath == null) {
47+
if (Objects.isNull(textPath)) {
4748
return processValue;
4849
}
4950
Optional<Object> existingValue = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, textPath, index);
5051
Optional<Object> embeddingValue = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, embeddingPath, index);
5152

5253
if (existingValue.isPresent() && embeddingValue.isPresent() && existingValue.get().equals(processValue)) {
5354
ProcessorUtils.setValueToSource(sourceAndMetadataMap, embeddingPath, embeddingValue.get(), index);
54-
return null; // if successfully copied, return null to be filtered out from process map
55+
// if successfully copied, return null to be filtered out from process map
56+
return null;
5557
}
56-
return processValue; // processValue and existingValue are different, return processValue to be included in process map
58+
// processValue and existingValue are different, return processValue to be included in process map
59+
return processValue;
5760
}
5861

5962
/**
6063
* Filters List value by checking if the texts in list are identical in both the existing and new document.
6164
* If lists are equal, the corresponding embeddings are copied
62-
*
6365
* @return empty list if embeddings are reused; the original list otherwise.
6466
*/
6567
@Override
@@ -73,8 +75,10 @@ public List<Object> filterInferenceValuesInList(
7375
) {
7476
if (processList.equals(existingList)) {
7577
ProcessorUtils.setValueToSource(sourceAndMetadataMap, fullEmbeddingKey, embeddingList);
76-
return Collections.emptyList(); // if successfully copied, return empty list to be filtered out from process map
78+
// if successfully copied, return empty list to be filtered out from process map
79+
return Collections.emptyList();
7780
}
78-
return processList; // source list and existing list are different, return processList to be included in process map
81+
// source list and existing list are different, return processList to be included in process map
82+
return processList;
7983
}
8084
}

src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java

+16-9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.HashMap;
1313
import java.util.List;
1414
import java.util.Map;
15+
import java.util.Objects;
1516
import java.util.Optional;
1617
import java.util.Stack;
1718

@@ -117,7 +118,7 @@ public static void removeTargetFieldFromSource(final Map<String, Object> sourceA
117118
if (key.equals(lastKey)) {
118119
break;
119120
}
120-
currentMap = (Map<String, Object>) currentMap.get(key);
121+
currentMap = castToMap(currentMap.get(key));
121122
}
122123

123124
// Remove the last key this is guaranteed
@@ -132,8 +133,7 @@ public static void removeTargetFieldFromSource(final Map<String, Object> sourceA
132133
parentMap = currentParentMapWithChild.v1();
133134
key = currentParentMapWithChild.v2();
134135

135-
@SuppressWarnings("unchecked")
136-
Map<String, Object> innerMap = (Map<String, Object>) parentMap.get(key);
136+
Map<String, Object> innerMap = castToMap(parentMap.get(key));
137137

138138
if (innerMap != null && innerMap.isEmpty()) {
139139
parentMap.remove(key);
@@ -167,13 +167,13 @@ public static Optional<Object> getValueFromSource(final Map<String, Object> sour
167167
for (String key : keys) {
168168
currentValue = currentValue.flatMap(value -> {
169169
if (value instanceof ArrayList<?> && index != -1) {
170-
Object listValue = ((ArrayList) value).get(index);
170+
Object listValue = (castToObjectList(value)).get(index);
171171
if (listValue instanceof Map) {
172-
Map<String, Object> currentMap = (Map<String, Object>) listValue;
172+
Map<String, Object> currentMap = castToMap(listValue);
173173
return Optional.ofNullable(currentMap.get(key));
174174
}
175175
} else if (value instanceof Map<?, ?>) {
176-
Map<String, Object> currentMap = (Map<String, Object>) value;
176+
Map<String, Object> currentMap = castToMap(value);
177177
return Optional.ofNullable(currentMap.get(key));
178178
}
179179
return Optional.empty();
@@ -207,7 +207,7 @@ public static void setValueToSource(Map<String, Object> sourceAsMap, String targ
207207
*/
208208

209209
public static void setValueToSource(Map<String, Object> sourceAsMap, String targetKey, Object targetValue, int index) {
210-
if (sourceAsMap == null || targetKey == null) return;
210+
if (Objects.isNull(sourceAsMap) || Objects.isNull(targetKey)) return;
211211

212212
String[] keys = targetKey.split("\\.");
213213
Map<String, Object> current = sourceAsMap;
@@ -217,10 +217,10 @@ public static void setValueToSource(Map<String, Object> sourceAsMap, String targ
217217
if (next instanceof ArrayList<?> list) {
218218
if (index < 0 || index >= list.size()) return;
219219
if (list.get(index) instanceof Map) {
220-
current = (Map<String, Object>) list.get(index);
220+
current = castToMap(list.get(index));
221221
}
222222
} else if (next instanceof Map) {
223-
current = (Map<String, Object>) next;
223+
current = castToMap(next);
224224
} else {
225225
throw new IllegalStateException("Unexpected data structure at " + keys[i]);
226226
}
@@ -274,4 +274,11 @@ public static boolean isNumeric(Object value) {
274274
public static Map<String, Object> castToMap(Object obj) {
275275
return (Map<String, Object>) obj;
276276
}
277+
278+
// This method should be used only when you are certain the object is a `List<Object>`.
279+
// It is recommended to use this method as a last resort.
280+
@SuppressWarnings("unchecked")
281+
public static List<Object> castToObjectList(Object obj) {
282+
return (List<Object>) obj;
283+
}
277284
}

src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java

+19
Original file line numberDiff line numberDiff line change
@@ -271,4 +271,23 @@ public void testFlattenAndFlip_withMultipleLevelsWithNestedMaps_thenSuccess() {
271271
Map<String, String> actual = ProcessorDocumentUtils.flattenAndFlip(nestedMap);
272272
assertEquals(expected, actual);
273273
}
274+
275+
public void testUnflatten_withListOfObject_thenSuccess() {
276+
Map<String, Object> map1 = Map.of("b.c", "d", "f", "h");
277+
Map<String, Object> map2 = Map.of("b.c", "e", "f", "i");
278+
List<Map<String, Object>> list = Arrays.asList(map1, map2);
279+
Map<String, Object> input = Map.of("a", list);
280+
281+
Map<String, Object> nestedB1 = Map.of("c", "d");
282+
Map<String, Object> expectedMap1 = Map.of("b", nestedB1, "f", "h");
283+
Map<String, Object> nestedB2 = Map.of("c", "e");
284+
Map<String, Object> expectedMap2 = Map.of("b", nestedB2, "f", "i");
285+
286+
List<Map<String, Object>> expectedList = Arrays.asList(expectedMap1, expectedMap2);
287+
288+
Map<String, Object> expected = Map.of("a", expectedList);
289+
290+
Map<String, Object> result = ProcessorDocumentUtils.unflattenJson(input);
291+
assertEquals(expected, result);
292+
}
274293
}

0 commit comments

Comments
 (0)