9
9
import lombok .Getter ;
10
10
import lombok .Setter ;
11
11
import org .opensearch .action .search .SearchRequest ;
12
- import org .opensearch .common .collect .Tuple ;
13
12
import org .opensearch .index .query .BoolQueryBuilder ;
14
13
import org .opensearch .index .query .QueryBuilder ;
15
14
import org .opensearch .ingest .ConfigurationUtils ;
16
15
import org .opensearch .neuralsearch .query .NeuralSparseQueryBuilder ;
16
+ import org .opensearch .neuralsearch .util .prune .PruneType ;
17
+ import org .opensearch .neuralsearch .util .prune .PruneUtils ;
17
18
import org .opensearch .search .builder .SearchSourceBuilder ;
18
19
import org .opensearch .search .pipeline .AbstractProcessor ;
19
20
import org .opensearch .search .pipeline .Processor ;
20
21
import org .opensearch .search .pipeline .SearchRequestProcessor ;
21
22
import org .opensearch .search .rescore .QueryRescorerBuilder ;
22
23
import org .opensearch .search .rescore .RescorerBuilder ;
23
24
24
- import java .util .Collections ;
25
25
import java .util .Locale ;
26
26
import java .util .Map ;
27
27
import java .util .Objects ;
28
- import java .util .stream .Collectors ;
29
28
30
29
/**
31
30
* A SearchRequestProcessor to generate two-phase NeuralSparseQueryBuilder,
@@ -37,41 +36,37 @@ public class NeuralSparseTwoPhaseProcessor extends AbstractProcessor implements
37
36
38
37
public static final String TYPE = "neural_sparse_two_phase_processor" ;
39
38
private boolean enabled ;
40
- private float ratio ;
39
+ private float pruneRatio ;
40
+ private PruneType pruneType ;
41
41
private float windowExpansion ;
42
42
private int maxWindowSize ;
43
43
private static final String PARAMETER_KEY = "two_phase_parameter" ;
44
- private static final String RATIO_KEY = "prune_ratio" ;
45
44
private static final String ENABLE_KEY = "enabled" ;
46
45
private static final String EXPANSION_KEY = "expansion_rate" ;
47
46
private static final String MAX_WINDOW_SIZE_KEY = "max_window_size" ;
48
47
private static final boolean DEFAULT_ENABLED = true ;
49
48
private static final float DEFAULT_RATIO = 0.4f ;
49
+ private static final PruneType DEFAULT_PRUNE_TYPE = PruneType .MAX_RATIO ;
50
50
private static final float DEFAULT_WINDOW_EXPANSION = 5.0f ;
51
51
private static final int DEFAULT_MAX_WINDOW_SIZE = 10000 ;
52
52
private static final int DEFAULT_BASE_QUERY_SIZE = 10 ;
53
53
private static final int MAX_WINDOWS_SIZE_LOWER_BOUND = 50 ;
54
54
private static final float WINDOW_EXPANSION_LOWER_BOUND = 1.0f ;
55
- private static final float RATIO_LOWER_BOUND = 0f ;
56
- private static final float RATIO_UPPER_BOUND = 1f ;
57
55
58
56
protected NeuralSparseTwoPhaseProcessor (
59
57
String tag ,
60
58
String description ,
61
59
boolean ignoreFailure ,
62
60
boolean enabled ,
63
- float ratio ,
61
+ float pruneRatio ,
62
+ PruneType pruneType ,
64
63
float windowExpansion ,
65
64
int maxWindowSize
66
65
) {
67
66
super (tag , description , ignoreFailure );
68
67
this .enabled = enabled ;
69
- if (ratio < RATIO_LOWER_BOUND || ratio > RATIO_UPPER_BOUND ) {
70
- throw new IllegalArgumentException (
71
- String .format (Locale .ROOT , "The two_phase_parameter.prune_ratio must be within [0, 1]. Received: %f" , ratio )
72
- );
73
- }
74
- this .ratio = ratio ;
68
+ this .pruneRatio = pruneRatio ;
69
+ this .pruneType = pruneType ;
75
70
if (windowExpansion < WINDOW_EXPANSION_LOWER_BOUND ) {
76
71
throw new IllegalArgumentException (
77
72
String .format (Locale .ROOT , "The two_phase_parameter.expansion_rate must >= 1.0. Received: %f" , windowExpansion )
@@ -93,7 +88,7 @@ protected NeuralSparseTwoPhaseProcessor(
93
88
*/
94
89
@ Override
95
90
public SearchRequest processRequest (final SearchRequest request ) {
96
- if (!enabled || ratio == 0f ) {
91
+ if (!enabled || pruneRatio == 0f ) {
97
92
return request ;
98
93
}
99
94
QueryBuilder queryBuilder = request .source ().query ();
@@ -117,43 +112,6 @@ public String getType() {
117
112
return TYPE ;
118
113
}
119
114
120
- /**
121
- * Based on ratio, split a Map into two map by the value.
122
- *
123
- * @param queryTokens the queryTokens map, key is the token String, value is the score.
124
- * @param thresholdRatio The ratio that control how tokens map be split.
125
- * @return A tuple has two element, { token map whose value above threshold, token map whose value below threshold }
126
- */
127
- public static Tuple <Map <String , Float >, Map <String , Float >> splitQueryTokensByRatioedMaxScoreAsThreshold (
128
- final Map <String , Float > queryTokens ,
129
- final float thresholdRatio
130
- ) {
131
- if (Objects .isNull (queryTokens )) {
132
- throw new IllegalArgumentException ("Query tokens cannot be null or empty." );
133
- }
134
- float max = 0f ;
135
- for (Float value : queryTokens .values ()) {
136
- max = Math .max (value , max );
137
- }
138
- float threshold = max * thresholdRatio ;
139
-
140
- Map <Boolean , Map <String , Float >> queryTokensByScore = queryTokens .entrySet ()
141
- .stream ()
142
- .collect (
143
- Collectors .partitioningBy (entry -> entry .getValue () >= threshold , Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ))
144
- );
145
-
146
- Map <String , Float > highScoreTokens = queryTokensByScore .get (Boolean .TRUE );
147
- Map <String , Float > lowScoreTokens = queryTokensByScore .get (Boolean .FALSE );
148
- if (Objects .isNull (highScoreTokens )) {
149
- highScoreTokens = Collections .emptyMap ();
150
- }
151
- if (Objects .isNull (lowScoreTokens )) {
152
- lowScoreTokens = Collections .emptyMap ();
153
- }
154
- return Tuple .tuple (highScoreTokens , lowScoreTokens );
155
- }
156
-
157
115
private QueryBuilder getNestedQueryBuilderFromNeuralSparseQueryBuilderMap (
158
116
final Multimap <NeuralSparseQueryBuilder , Float > queryBuilderFloatMap
159
117
) {
@@ -201,7 +159,10 @@ private Multimap<NeuralSparseQueryBuilder, Float> collectNeuralSparseQueryBuilde
201
159
* - Docs besides TopDocs: Score = HighScoreToken's score
202
160
* - Final TopDocs: Score = HighScoreToken's score + LowScoreToken's score
203
161
*/
204
- NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder .getCopyNeuralSparseQueryBuilderForTwoPhase (ratio );
162
+ NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder .getCopyNeuralSparseQueryBuilderForTwoPhase (
163
+ pruneRatio ,
164
+ pruneType
165
+ );
205
166
result .put (modifiedQueryBuilder , updatedBoost );
206
167
}
207
168
// We only support BoostQuery, BooleanQuery and NeuralSparseQuery now. For other compound query type which are not support now, will
@@ -248,16 +209,40 @@ public NeuralSparseTwoPhaseProcessor create(
248
209
boolean enabled = ConfigurationUtils .readBooleanProperty (TYPE , tag , config , ENABLE_KEY , DEFAULT_ENABLED );
249
210
Map <String , Object > twoPhaseConfigMap = ConfigurationUtils .readOptionalMap (TYPE , tag , config , PARAMETER_KEY );
250
211
251
- float ratio = DEFAULT_RATIO ;
212
+ float pruneRatio = DEFAULT_RATIO ;
252
213
float windowExpansion = DEFAULT_WINDOW_EXPANSION ;
253
214
int maxWindowSize = DEFAULT_MAX_WINDOW_SIZE ;
215
+ PruneType pruneType = DEFAULT_PRUNE_TYPE ;
254
216
if (Objects .nonNull (twoPhaseConfigMap )) {
255
- ratio = ((Number ) twoPhaseConfigMap .getOrDefault (RATIO_KEY , ratio )).floatValue ();
217
+ pruneRatio = ((Number ) twoPhaseConfigMap .getOrDefault (PruneUtils . PRUNE_RATIO_FIELD , pruneRatio )).floatValue ();
256
218
windowExpansion = ((Number ) twoPhaseConfigMap .getOrDefault (EXPANSION_KEY , windowExpansion )).floatValue ();
257
219
maxWindowSize = ((Number ) twoPhaseConfigMap .getOrDefault (MAX_WINDOW_SIZE_KEY , maxWindowSize )).intValue ();
220
+ pruneType = PruneType .fromString (
221
+ twoPhaseConfigMap .getOrDefault (PruneUtils .PRUNE_TYPE_FIELD , pruneType .getValue ()).toString ()
222
+ );
223
+ }
224
+ if (!PruneUtils .isValidPruneRatio (pruneType , pruneRatio )) {
225
+ throw new IllegalArgumentException (
226
+ String .format (
227
+ Locale .ROOT ,
228
+ "Illegal prune_ratio %f for prune_type: %s. %s" ,
229
+ pruneRatio ,
230
+ pruneType .getValue (),
231
+ PruneUtils .getValidPruneRatioDescription (pruneType )
232
+ )
233
+ );
258
234
}
259
235
260
- return new NeuralSparseTwoPhaseProcessor (tag , description , ignoreFailure , enabled , ratio , windowExpansion , maxWindowSize );
236
+ return new NeuralSparseTwoPhaseProcessor (
237
+ tag ,
238
+ description ,
239
+ ignoreFailure ,
240
+ enabled ,
241
+ pruneRatio ,
242
+ pruneType ,
243
+ windowExpansion ,
244
+ maxWindowSize
245
+ );
261
246
}
262
247
}
263
248
0 commit comments