26
26
import org .apache .commons .lang3 .StringUtils ;
27
27
import org .apache .commons .lang3 .tuple .ImmutablePair ;
28
28
import org .apache .commons .lang3 .tuple .Pair ;
29
+ import org .opensearch .action .get .MultiGetItemResponse ;
30
+ import org .opensearch .action .get .MultiGetRequest ;
29
31
import org .opensearch .common .collect .Tuple ;
30
32
import org .opensearch .core .action .ActionListener ;
31
33
import org .opensearch .core .common .util .CollectionUtils ;
@@ -54,6 +56,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
54
56
55
57
public static final String MODEL_ID_FIELD = "model_id" ;
56
58
public static final String FIELD_MAP_FIELD = "field_map" ;
59
+ public static final String INDEX_FIELD = "_index" ;
60
+ public static final String ID_FIELD = "_id" ;
57
61
private static final BiFunction <Object , Object , Object > REMAPPING_FUNCTION = (v1 , v2 ) -> {
58
62
if (v1 instanceof Collection && v2 instanceof Collection ) {
59
63
((Collection ) v1 ).addAll ((Collection ) v2 );
@@ -182,6 +186,15 @@ public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers,
182
186
handler .accept (ingestDocumentWrappers );
183
187
return ;
184
188
}
189
+ doSubBatchExecute (ingestDocumentWrappers , inferenceList , dataForInferences , handler );
190
+ }
191
+
192
+ protected void doSubBatchExecute (
193
+ List <IngestDocumentWrapper > ingestDocumentWrappers ,
194
+ List <String > inferenceList ,
195
+ List <DataForInference > dataForInferences ,
196
+ Consumer <List <IngestDocumentWrapper >> handler
197
+ ) {
185
198
Tuple <List <String >, Map <Integer , Integer >> sortedResult = sortByLengthAndReturnOriginalOrder (inferenceList );
186
199
inferenceList = sortedResult .v1 ();
187
200
Map <Integer , Integer > originalOrder = sortedResult .v2 ();
@@ -238,7 +251,7 @@ private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> or
238
251
return sortedResults ;
239
252
}
240
253
241
- private List <String > constructInferenceTexts (List <DataForInference > dataForInferences ) {
254
+ protected List <String > constructInferenceTexts (List <DataForInference > dataForInferences ) {
242
255
List <String > inferenceTexts = new ArrayList <>();
243
256
for (DataForInference dataForInference : dataForInferences ) {
244
257
if (dataForInference .getIngestDocumentWrapper ().getException () != null
@@ -250,7 +263,7 @@ private List<String> constructInferenceTexts(List<DataForInference> dataForInfer
250
263
return inferenceTexts ;
251
264
}
252
265
253
- private List <DataForInference > getDataForInference (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
266
+ protected List <DataForInference > getDataForInference (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
254
267
List <DataForInference > dataForInferences = new ArrayList <>();
255
268
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
256
269
Map <String , Object > processMap = null ;
@@ -272,7 +285,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
272
285
273
286
@ Getter
274
287
@ AllArgsConstructor
275
- private static class DataForInference {
288
+ protected static class DataForInference {
276
289
private final IngestDocumentWrapper ingestDocumentWrapper ;
277
290
private final Map <String , Object > processMap ;
278
291
private final List <String > inferenceList ;
@@ -415,6 +428,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
415
428
nlpResult .forEach (ingestDocument ::setFieldValue );
416
429
}
417
430
431
+ /**
432
+ * This method creates a MultiGetRequest from a list of ingest documents to be fetched for comparison
433
+ * @param ingestDocumentWrappers, list of ingest documents
434
+ * */
435
+ protected MultiGetRequest buildMultiGetRequest (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
436
+ MultiGetRequest multiGetRequest = new MultiGetRequest ();
437
+ for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
438
+ Object index = ingestDocumentWrapper .getIngestDocument ().getSourceAndMetadata ().get (INDEX_FIELD );
439
+ Object id = ingestDocumentWrapper .getIngestDocument ().getSourceAndMetadata ().get (ID_FIELD );
440
+ if (Objects .nonNull (index ) && Objects .nonNull (id )) {
441
+ multiGetRequest .add (index .toString (), id .toString ());
442
+ }
443
+ }
444
+ return multiGetRequest ;
445
+ }
446
+
447
+ /**
448
+ * This method creates a map of documents from MultiGetItemResponse where the key is document ID and value is corresponding document
449
+ * @param multiGetItemResponses, array of responses from Multi Get Request
450
+ * */
451
+ protected Map <String , Map <String , Object >> createDocumentMap (MultiGetItemResponse [] multiGetItemResponses ) {
452
+ Map <String , Map <String , Object >> existingDocuments = new HashMap <>();
453
+ for (MultiGetItemResponse item : multiGetItemResponses ) {
454
+ String id = item .getId ();
455
+ Map <String , Object > existingDocument = item .getResponse ().getSourceAsMap ();
456
+ existingDocuments .put (id , existingDocument );
457
+ }
458
+ return existingDocuments ;
459
+ }
460
+
418
461
@ SuppressWarnings ({ "unchecked" })
419
462
@ VisibleForTesting
420
463
Map <String , Object > buildNLPResult (Map <String , Object > processorMap , List <?> results , Map <String , Object > sourceAndMetadataMap ) {
@@ -582,7 +625,7 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
582
625
List <Map <String , Object >> keyToResult = new ArrayList <>();
583
626
sourceValue .stream ()
584
627
.filter (Objects ::nonNull ) // explicit null check is required since sourceValue can contain null values in cases where
585
- // sourceValue has been filtered
628
+ // sourceValue has been filtered
586
629
.forEachOrdered (x -> keyToResult .add (ImmutableMap .of (listTypeNestedMapKey , results .get (indexWrapper .index ++))));
587
630
return keyToResult ;
588
631
}
0 commit comments