5
5
package org .opensearch .neuralsearch .query ;
6
6
7
7
import static org .opensearch .knn .index .query .KNNQueryBuilder .FILTER_FIELD ;
8
+ import static org .opensearch .knn .index .query .KNNQueryBuilder .MAX_DISTANCE_FIELD ;
9
+ import static org .opensearch .knn .index .query .KNNQueryBuilder .METHOD_PARAMS_FIELD ;
10
+ import static org .opensearch .knn .index .query .KNNQueryBuilder .MIN_SCORE_FIELD ;
11
+ import static org .opensearch .neuralsearch .common .MinClusterVersionUtil .isClusterOnOrAfterMinReqVersion ;
12
+ import static org .opensearch .neuralsearch .common .MinClusterVersionUtil .isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport ;
13
+ import static org .opensearch .neuralsearch .common .MinClusterVersionUtil .isClusterOnOrAfterMinReqVersionForRadialSearch ;
8
14
import static org .opensearch .neuralsearch .common .VectorUtil .vectorAsListToArray ;
9
15
import static org .opensearch .neuralsearch .processor .TextImageEmbeddingProcessor .INPUT_IMAGE ;
10
16
import static org .opensearch .neuralsearch .processor .TextImageEmbeddingProcessor .INPUT_TEXT ;
19
25
import org .apache .commons .lang .builder .EqualsBuilder ;
20
26
import org .apache .commons .lang .builder .HashCodeBuilder ;
21
27
import org .apache .lucene .search .Query ;
22
- import org .opensearch .Version ;
23
28
import org .opensearch .common .SetOnce ;
24
29
import org .opensearch .core .ParseField ;
25
30
import org .opensearch .core .action .ActionListener ;
34
39
import org .opensearch .index .query .QueryRewriteContext ;
35
40
import org .opensearch .index .query .QueryShardContext ;
36
41
import org .opensearch .knn .index .query .KNNQueryBuilder ;
42
+ import org .opensearch .knn .index .query .parser .MethodParametersParser ;
43
+ import org .opensearch .neuralsearch .common .MinClusterVersionUtil ;
37
44
import org .opensearch .neuralsearch .ml .MLCommonsClientAccessor ;
38
- import org .opensearch .neuralsearch .util .NeuralSearchClusterUtil ;
39
45
40
46
import com .google .common .annotations .VisibleForTesting ;
41
47
@@ -69,18 +75,11 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
69
75
@ VisibleForTesting
70
76
static final ParseField QUERY_IMAGE_FIELD = new ParseField ("query_image" );
71
77
72
- @ VisibleForTesting
73
- static final ParseField MODEL_ID_FIELD = new ParseField ("model_id" );
78
+ public static final ParseField MODEL_ID_FIELD = new ParseField ("model_id" );
74
79
75
80
@ VisibleForTesting
76
81
static final ParseField K_FIELD = new ParseField ("k" );
77
82
78
- @ VisibleForTesting
79
- static final ParseField MAX_DISTANCE_FIELD = new ParseField ("max_distance" );
80
-
81
- @ VisibleForTesting
82
- static final ParseField MIN_SCORE_FIELD = new ParseField ("min_score" );
83
-
84
83
private static final int DEFAULT_K = 10 ;
85
84
86
85
private static MLCommonsClientAccessor ML_CLIENT ;
@@ -101,8 +100,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
101
100
@ Setter (AccessLevel .PACKAGE )
102
101
private Supplier <float []> vectorSupplier ;
103
102
private QueryBuilder filter ;
104
- private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version .V_2_11_0 ;
105
- private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version .V_2_14_0 ;
103
+ private Map <String , ?> methodParameters ;
106
104
107
105
/**
108
106
* Constructor from stream input
@@ -130,6 +128,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
130
128
this .maxDistance = in .readOptionalFloat ();
131
129
this .minScore = in .readOptionalFloat ();
132
130
}
131
+ if (isClusterOnOrAfterMinReqVersion (METHOD_PARAMS_FIELD .getPreferredName ())) {
132
+ this .methodParameters = MethodParametersParser .streamInput (in , MinClusterVersionUtil ::isClusterOnOrAfterMinReqVersion );
133
+ }
133
134
}
134
135
135
136
@ Override
@@ -152,6 +153,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
152
153
out .writeOptionalFloat (this .maxDistance );
153
154
out .writeOptionalFloat (this .minScore );
154
155
}
156
+ if (isClusterOnOrAfterMinReqVersion (METHOD_PARAMS_FIELD .getPreferredName ())) {
157
+ MethodParametersParser .streamOutput (out , methodParameters , MinClusterVersionUtil ::isClusterOnOrAfterMinReqVersion );
158
+ }
155
159
}
156
160
157
161
@ Override
@@ -174,6 +178,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
174
178
if (Objects .nonNull (minScore )) {
175
179
xContentBuilder .field (MIN_SCORE_FIELD .getPreferredName (), minScore );
176
180
}
181
+ if (Objects .nonNull (methodParameters )) {
182
+ MethodParametersParser .doXContent (xContentBuilder , methodParameters );
183
+ }
177
184
printBoostAndQueryName (xContentBuilder );
178
185
xContentBuilder .endObject ();
179
186
xContentBuilder .endObject ();
@@ -267,6 +274,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
267
274
} else if (token == XContentParser .Token .START_OBJECT ) {
268
275
if (FILTER_FIELD .match (currentFieldName , parser .getDeprecationHandler ())) {
269
276
neuralQueryBuilder .filter (parseInnerQueryBuilder (parser ));
277
+ } else if (METHOD_PARAMS_FIELD .match (currentFieldName , parser .getDeprecationHandler ())) {
278
+ neuralQueryBuilder .methodParameters (MethodParametersParser .fromXContent (parser ));
270
279
}
271
280
} else {
272
281
throw new ParsingException (
@@ -325,7 +334,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
325
334
maxDistance (),
326
335
minScore (),
327
336
vectorSetOnce ::get ,
328
- filter ()
337
+ filter (),
338
+ methodParameters ()
329
339
);
330
340
}
331
341
@@ -358,14 +368,6 @@ public String getWriteableName() {
358
368
return NAME ;
359
369
}
360
370
361
- private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport () {
362
- return NeuralSearchClusterUtil .instance ().getClusterMinVersion ().onOrAfter (MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID );
363
- }
364
-
365
- private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch () {
366
- return NeuralSearchClusterUtil .instance ().getClusterMinVersion ().onOrAfter (MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH );
367
- }
368
-
369
371
private static boolean validateKNNQueryType (NeuralQueryBuilder neuralQueryBuilder ) {
370
372
int queryCount = 0 ;
371
373
if (neuralQueryBuilder .k () != null ) {
0 commit comments