6
6
7
7
import java .util .List ;
8
8
import java .util .Map ;
9
+ import java .util .Objects ;
9
10
import java .util .function .BiConsumer ;
10
11
import java .util .function .Consumer ;
12
+ import java .util .stream .Collectors ;
11
13
14
+ import org .opensearch .action .get .GetAction ;
15
+ import org .opensearch .action .get .GetRequest ;
12
16
import org .opensearch .cluster .service .ClusterService ;
13
17
import org .opensearch .core .action .ActionListener ;
14
18
import org .opensearch .env .Environment ;
15
19
import org .opensearch .ingest .IngestDocument ;
16
20
import org .opensearch .neuralsearch .ml .MLCommonsClientAccessor ;
17
21
18
22
import lombok .extern .log4j .Log4j2 ;
23
+ import org .opensearch .neuralsearch .processor .optimization .TextEmbeddingInferenceFilter ;
24
+ import org .opensearch .transport .client .OpenSearchClient ;
19
25
20
26
/**
21
27
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
@@ -26,34 +32,71 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
26
32
27
33
public static final String TYPE = "text_embedding" ;
28
34
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn" ;
35
+ public static final String SKIP_EXISTING = "skip_existing" ;
36
+ public static final boolean DEFAULT_SKIP_EXISTING = false ;
37
+ private static final String INDEX_FIELD = "_index" ;
38
+ private static final String ID_FIELD = "_id" ;
39
+ private final OpenSearchClient openSearchClient ;
40
+ private final boolean skipExisting ;
41
+ private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter ;
29
42
30
43
/**
 * Creates a text embedding processor.
 *
 * @param tag                          processor tag, forwarded to the parent constructor
 * @param description                  processor description, forwarded to the parent constructor
 * @param batchSize                    batch size, forwarded to the parent constructor
 * @param modelId                      id of the model used for inference, forwarded to the parent constructor
 * @param fieldMap                     field mapping, forwarded to the parent constructor
 * @param skipExisting                 whether the stored document should be consulted before running inference
 * @param textEmbeddingInferenceFilter filter used to narrow the process map when {@code skipExisting} is enabled
 * @param openSearchClient             client used to fetch the stored document when {@code skipExisting} is enabled
 * @param clientAccessor               ML Commons client accessor, forwarded to the parent constructor
 * @param environment                  node environment, forwarded to the parent constructor
 * @param clusterService               cluster service, forwarded to the parent constructor
 */
public TextEmbeddingProcessor(
    String tag,
    String description,
    int batchSize,
    String modelId,
    Map<String, Object> fieldMap,
    boolean skipExisting,
    TextEmbeddingInferenceFilter textEmbeddingInferenceFilter,
    OpenSearchClient openSearchClient,
    MLCommonsClientAccessor clientAccessor,
    Environment environment,
    ClusterService clusterService
) {
    super(
        tag,
        description,
        batchSize,
        TYPE,
        LIST_TYPE_NESTED_MAP_KEY,
        modelId,
        fieldMap,
        clientAccessor,
        environment,
        clusterService
    );
    // Collaborators for the skip-existing path; plain assignments with no ordering dependency.
    this.openSearchClient = openSearchClient;
    this.textEmbeddingInferenceFilter = textEmbeddingInferenceFilter;
    this.skipExisting = skipExisting;
}
42
61
43
62
@ Override
44
63
public void doExecute (
45
64
IngestDocument ingestDocument ,
46
- Map <String , Object > ProcessMap ,
65
+ Map <String , Object > processMap ,
47
66
List <String > inferenceList ,
48
67
BiConsumer <IngestDocument , Exception > handler
49
68
) {
50
- mlCommonsClientAccessor .inferenceSentences (
51
- TextInferenceRequest .builder ().modelId (this .modelId ).inputTexts (inferenceList ).build (),
52
- ActionListener .wrap (vectors -> {
53
- setVectorFieldsToDocument (ingestDocument , ProcessMap , vectors );
69
+ // skip existing flag is turned off. Call model inference without filtering
70
+ if (skipExisting == false ) {
71
+ makeInferenceCall (ingestDocument , processMap , inferenceList , handler );
72
+ return ;
73
+ }
74
+ // if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings are copied
75
+ String index = ingestDocument .getSourceAndMetadata ().get (INDEX_FIELD ).toString ();
76
+ String id = ingestDocument .getSourceAndMetadata ().get (ID_FIELD ).toString ();
77
+ openSearchClient .execute (GetAction .INSTANCE , new GetRequest (index , id ), ActionListener .wrap (response -> {
78
+ final Map <String , Object > existingDocument = response .getSourceAsMap ();
79
+ if (existingDocument == null || existingDocument .isEmpty ()) {
80
+ makeInferenceCall (ingestDocument , processMap , inferenceList , handler );
81
+ return ;
82
+ }
83
+ // filter given ProcessMap by comparing existing document with ingestDocument
84
+ Map <String , Object > filteredProcessMap = textEmbeddingInferenceFilter .filter (
85
+ existingDocument ,
86
+ ingestDocument .getSourceAndMetadata (),
87
+ processMap
88
+ );
89
+ // create inference list based on filtered ProcessMap
90
+ List <String > filteredInferenceList = createInferenceList (filteredProcessMap ).stream ()
91
+ .filter (Objects ::nonNull )
92
+ .collect (Collectors .toList ());
93
+ if (filteredInferenceList .isEmpty ()) {
54
94
handler .accept (ingestDocument , null );
55
- }, e -> { handler .accept (null , e ); })
56
- );
95
+ } else {
96
+ makeInferenceCall (ingestDocument , filteredProcessMap , filteredInferenceList , handler );
97
+ }
98
+
99
+ }, e -> { handler .accept (null , e ); }));
57
100
}
58
101
59
102
@ Override
0 commit comments