5
5
package org .opensearch .neuralsearch .query ;
6
6
7
7
import static org .opensearch .knn .index .query .KNNQueryBuilder .FILTER_FIELD ;
8
+ import static org .opensearch .knn .index .query .KNNQueryBuilder .MAX_DISTANCE_FIELD ;
9
+ import static org .opensearch .knn .index .query .KNNQueryBuilder .METHOD_PARAMS_FIELD ;
10
+ import static org .opensearch .knn .index .query .KNNQueryBuilder .MIN_SCORE_FIELD ;
11
+ import static org .opensearch .neuralsearch .common .MinClusterVersionUtil .isClusterOnOrAfterMinReqVersion ;
12
+ import static org .opensearch .neuralsearch .common .MinClusterVersionUtil .isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport ;
13
+ import static org .opensearch .neuralsearch .common .MinClusterVersionUtil .isClusterOnOrAfterMinReqVersionForRadialSearch ;
8
14
import static org .opensearch .neuralsearch .common .VectorUtil .vectorAsListToArray ;
9
15
import static org .opensearch .neuralsearch .processor .TextImageEmbeddingProcessor .INPUT_IMAGE ;
10
16
import static org .opensearch .neuralsearch .processor .TextImageEmbeddingProcessor .INPUT_TEXT ;
19
25
import org .apache .commons .lang .builder .EqualsBuilder ;
20
26
import org .apache .commons .lang .builder .HashCodeBuilder ;
21
27
import org .apache .lucene .search .Query ;
22
- import org .opensearch .Version ;
23
28
import org .opensearch .common .SetOnce ;
24
29
import org .opensearch .core .ParseField ;
25
30
import org .opensearch .core .action .ActionListener ;
34
39
import org .opensearch .index .query .QueryRewriteContext ;
35
40
import org .opensearch .index .query .QueryShardContext ;
36
41
import org .opensearch .knn .index .query .KNNQueryBuilder ;
42
+ import org .opensearch .knn .index .query .parser .MethodParametersParser ;
43
+ import org .opensearch .neuralsearch .common .MinClusterVersionUtil ;
37
44
import org .opensearch .neuralsearch .ml .MLCommonsClientAccessor ;
38
- import org .opensearch .neuralsearch .util .NeuralSearchClusterUtil ;
39
45
40
46
import com .google .common .annotations .VisibleForTesting ;
41
47
@@ -69,18 +75,11 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
69
75
@ VisibleForTesting
70
76
static final ParseField QUERY_IMAGE_FIELD = new ParseField ("query_image" );
71
77
72
- @ VisibleForTesting
73
- static final ParseField MODEL_ID_FIELD = new ParseField ("model_id" );
78
+ public static final ParseField MODEL_ID_FIELD = new ParseField ("model_id" );
74
79
75
80
@ VisibleForTesting
76
81
static final ParseField K_FIELD = new ParseField ("k" );
77
82
78
- @ VisibleForTesting
79
- static final ParseField MAX_DISTANCE_FIELD = new ParseField ("max_distance" );
80
-
81
- @ VisibleForTesting
82
- static final ParseField MIN_SCORE_FIELD = new ParseField ("min_score" );
83
-
84
83
private static final int DEFAULT_K = 10 ;
85
84
86
85
private static MLCommonsClientAccessor ML_CLIENT ;
@@ -101,8 +100,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
101
100
@ Setter (AccessLevel .PACKAGE )
102
101
private Supplier <float []> vectorSupplier ;
103
102
private QueryBuilder filter ;
104
- private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version .V_2_11_0 ;
105
- private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version .V_2_14_0 ;
103
+ private Map <String , ?> methodParameters ;
106
104
107
105
/**
108
106
* Constructor from stream input
@@ -130,6 +128,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
130
128
this .maxDistance = in .readOptionalFloat ();
131
129
this .minScore = in .readOptionalFloat ();
132
130
}
131
+ if (isClusterOnOrAfterMinReqVersion (METHOD_PARAMS_FIELD .getPreferredName ())) {
132
+ this .methodParameters = MethodParametersParser .streamInput (in , MinClusterVersionUtil ::isClusterOnOrAfterMinReqVersion );
133
+ }
133
134
}
134
135
135
136
@ Override
@@ -152,6 +153,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
152
153
out .writeOptionalFloat (this .maxDistance );
153
154
out .writeOptionalFloat (this .minScore );
154
155
}
156
+ if (isClusterOnOrAfterMinReqVersion (METHOD_PARAMS_FIELD .getPreferredName ())) {
157
+ MethodParametersParser .streamOutput (out , methodParameters , MinClusterVersionUtil ::isClusterOnOrAfterMinReqVersion );
158
+ }
155
159
}
156
160
157
161
@ Override
@@ -174,6 +178,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
174
178
if (Objects .nonNull (minScore )) {
175
179
xContentBuilder .field (MIN_SCORE_FIELD .getPreferredName (), minScore );
176
180
}
181
+ if (Objects .nonNull (methodParameters )) {
182
+ MethodParametersParser .doXContent (xContentBuilder , methodParameters );
183
+ }
177
184
printBoostAndQueryName (xContentBuilder );
178
185
xContentBuilder .endObject ();
179
186
xContentBuilder .endObject ();
@@ -267,6 +274,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
267
274
} else if (token == XContentParser .Token .START_OBJECT ) {
268
275
if (FILTER_FIELD .match (currentFieldName , parser .getDeprecationHandler ())) {
269
276
neuralQueryBuilder .filter (parseInnerQueryBuilder (parser ));
277
+ } else if (METHOD_PARAMS_FIELD .match (currentFieldName , parser .getDeprecationHandler ())) {
278
+ neuralQueryBuilder .methodParameters (MethodParametersParser .fromXContent (parser ));
270
279
}
271
280
} else {
272
281
throw new ParsingException (
@@ -292,15 +301,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
292
301
if (vectorSupplier ().get () == null ) {
293
302
return this ;
294
303
}
295
- KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder (fieldName (), vectorSupplier .get ()).filter (filter ());
296
- if (maxDistance != null ) {
297
- knnQueryBuilder .maxDistance (maxDistance );
298
- } else if (minScore != null ) {
299
- knnQueryBuilder .minScore (minScore );
300
- } else {
301
- knnQueryBuilder .k (k );
302
- }
303
- return knnQueryBuilder ;
304
+ return KNNQueryBuilder .builder ()
305
+ .fieldName (fieldName ())
306
+ .vector (vectorSupplier .get ())
307
+ .filter (filter ())
308
+ .maxDistance (maxDistance )
309
+ .minScore (minScore )
310
+ .k (k )
311
+ .build ();
304
312
}
305
313
306
314
SetOnce <float []> vectorSetOnce = new SetOnce <>();
@@ -326,7 +334,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
326
334
maxDistance (),
327
335
minScore (),
328
336
vectorSetOnce ::get ,
329
- filter ()
337
+ filter (),
338
+ methodParameters ()
330
339
);
331
340
}
332
341
@@ -359,14 +368,6 @@ public String getWriteableName() {
359
368
return NAME ;
360
369
}
361
370
362
- private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport () {
363
- return NeuralSearchClusterUtil .instance ().getClusterMinVersion ().onOrAfter (MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID );
364
- }
365
-
366
- private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch () {
367
- return NeuralSearchClusterUtil .instance ().getClusterMinVersion ().onOrAfter (MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH );
368
- }
369
-
370
371
private static boolean validateKNNQueryType (NeuralQueryBuilder neuralQueryBuilder ) {
371
372
int queryCount = 0 ;
372
373
if (neuralQueryBuilder .k () != null ) {
0 commit comments