Skip to content

Commit 726a482

Browse files
authored
Adds method_parameters in neural search query to support ef_search (#787)
Signed-off-by: Tejas Shah <shatejas@amazon.com>
1 parent 54ac672 commit 726a482

File tree

14 files changed

+193
-56
lines changed

14 files changed

+193
-56
lines changed

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.Arrays;
1111
import java.util.List;
1212
import java.util.Map;
13+
1314
import org.opensearch.index.query.MatchQueryBuilder;
1415
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
1516
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
@@ -69,6 +70,7 @@ private void validateNormalizationProcessor(final String fileName, final String
6970
loadModel(modelId);
7071
addDocuments(getIndexNameForTest(), false);
7172
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName);
73+
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName, Map.of("ef_search", 100));
7274
} finally {
7375
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
7476
}
@@ -96,10 +98,14 @@ private void createSearchPipeline(final String pipelineName) {
9698
);
9799
}
98100

99-
private void validateTestIndex(final String modelId, final String index, final String searchPipeline) throws Exception {
101+
private void validateTestIndex(final String modelId, final String index, final String searchPipeline) {
102+
validateTestIndex(modelId, index, searchPipeline, null);
103+
}
104+
105+
private void validateTestIndex(final String modelId, final String index, final String searchPipeline, Map<String, ?> methodParameters) {
100106
int docCount = getDocCount(index);
101107
assertEquals(6, docCount);
102-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId);
108+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
103109
Map<String, Object> searchResponseAsMap = search(index, hybridQueryBuilder, null, 1, Map.of("search_pipeline", searchPipeline));
104110
assertNotNull(searchResponseAsMap);
105111
int hits = getHitCount(searchResponseAsMap);
@@ -110,12 +116,15 @@ private void validateTestIndex(final String modelId, final String index, final S
110116
}
111117
}
112118

113-
private HybridQueryBuilder getQueryBuilder(final String modelId) {
119+
private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters) {
114120
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
115121
neuralQueryBuilder.fieldName("passage_embedding");
116122
neuralQueryBuilder.modelId(modelId);
117123
neuralQueryBuilder.queryText(QUERY);
118124
neuralQueryBuilder.k(5);
125+
if (methodParameters != null) {
126+
neuralQueryBuilder.methodParameters(methodParameters);
127+
}
119128

120129
MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);
121130

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ private void validateIndexQuery(final String modelId) {
6060
null,
6161
0.01f,
6262
null,
63+
null,
6364
null
6465
);
6566
Map<String, Object> responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
@@ -74,6 +75,7 @@ private void validateIndexQuery(final String modelId) {
7475
100000f,
7576
null,
7677
null,
78+
null,
7779
null
7880
);
7981
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ private void validateTestIndex(final String modelId) throws Exception {
6262
null,
6363
null,
6464
null,
65+
null,
6566
null
6667
);
6768
Map<String, Object> response = search(getIndexNameForTest(), neuralQueryBuilder, 1);

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java

+11-2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
7373
loadModel(modelId);
7474
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
7575
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId);
76+
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, Map.of("ef_search", 100));
7677
} finally {
7778
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
7879
}
@@ -83,10 +84,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
8384
}
8485

8586
private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId) throws Exception {
87+
validateTestIndexOnUpgrade(numberOfDocs, modelId, null);
88+
}
89+
90+
private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, Map<String, ?> methodParameters)
91+
throws Exception {
8692
int docCount = getDocCount(getIndexNameForTest());
8793
assertEquals(numberOfDocs, docCount);
8894
loadModel(modelId);
89-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId);
95+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
9096
Map<String, Object> searchResponseAsMap = search(
9197
getIndexNameForTest(),
9298
hybridQueryBuilder,
@@ -103,12 +109,15 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
103109
}
104110
}
105111

106-
private HybridQueryBuilder getQueryBuilder(final String modelId) {
112+
private HybridQueryBuilder getQueryBuilder(final String modelId, final Map<String, ?> methodParameters) {
107113
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
108114
neuralQueryBuilder.fieldName("passage_embedding");
109115
neuralQueryBuilder.modelId(modelId);
110116
neuralQueryBuilder.queryText(QUERY);
111117
neuralQueryBuilder.k(5);
118+
if (methodParameters != null) {
119+
neuralQueryBuilder.methodParameters(methodParameters);
120+
}
112121

113122
MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);
114123

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java

+2
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
8686
null,
8787
0.01f,
8888
null,
89+
null,
8990
null
9091
);
9192
Map<String, Object> responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
@@ -100,6 +101,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
100101
100000f,
101102
null,
102103
null,
104+
null,
103105
null
104106
);
105107
Map<String, Object> responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java

+1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
8585
null,
8686
null,
8787
null,
88+
null,
8889
null
8990
);
9091
Map<String, Object> responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.common;
6+
7+
import com.google.common.collect.ImmutableMap;
8+
import org.opensearch.Version;
9+
import org.opensearch.knn.index.IndexUtil;
10+
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
11+
12+
import java.util.Map;
13+
14+
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
15+
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
16+
import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD;
17+
18+
/**
19+
* A util class which holds the logic to determine the min version supported by the request parameters
20+
*/
21+
public final class MinClusterVersionUtil {
22+
23+
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
24+
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
25+
26+
// Note this minimal version will act as a override
27+
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
28+
.put(MODEL_ID_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID)
29+
.put(MAX_DISTANCE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
30+
.put(MIN_SCORE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
31+
.build();
32+
33+
public static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
34+
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
35+
}
36+
37+
public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
38+
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
39+
}
40+
41+
public static boolean isClusterOnOrAfterMinReqVersion(String key) {
42+
Version version;
43+
if (MINIMAL_VERSION_NEURAL.containsKey(key)) {
44+
version = MINIMAL_VERSION_NEURAL.get(key);
45+
} else {
46+
version = IndexUtil.minimalRequiredVersionMap.get(key);
47+
}
48+
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(version);
49+
}
50+
}

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

+31-30
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
package org.opensearch.neuralsearch.query;
66

77
import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
8+
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
9+
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
10+
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
11+
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
12+
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport;
13+
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch;
814
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;
915
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE;
1016
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT;
@@ -19,7 +25,6 @@
1925
import org.apache.commons.lang.builder.EqualsBuilder;
2026
import org.apache.commons.lang.builder.HashCodeBuilder;
2127
import org.apache.lucene.search.Query;
22-
import org.opensearch.Version;
2328
import org.opensearch.common.SetOnce;
2429
import org.opensearch.core.ParseField;
2530
import org.opensearch.core.action.ActionListener;
@@ -34,8 +39,9 @@
3439
import org.opensearch.index.query.QueryRewriteContext;
3540
import org.opensearch.index.query.QueryShardContext;
3641
import org.opensearch.knn.index.query.KNNQueryBuilder;
42+
import org.opensearch.knn.index.query.parser.MethodParametersParser;
43+
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
3744
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
38-
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
3945

4046
import com.google.common.annotations.VisibleForTesting;
4147

@@ -69,18 +75,11 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
6975
@VisibleForTesting
7076
static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image");
7177

72-
@VisibleForTesting
73-
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
78+
public static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
7479

7580
@VisibleForTesting
7681
static final ParseField K_FIELD = new ParseField("k");
7782

78-
@VisibleForTesting
79-
static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");
80-
81-
@VisibleForTesting
82-
static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");
83-
8483
private static final int DEFAULT_K = 10;
8584

8685
private static MLCommonsClientAccessor ML_CLIENT;
@@ -101,8 +100,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
101100
@Setter(AccessLevel.PACKAGE)
102101
private Supplier<float[]> vectorSupplier;
103102
private QueryBuilder filter;
104-
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
105-
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
103+
private Map<String, ?> methodParameters;
106104

107105
/**
108106
* Constructor from stream input
@@ -130,6 +128,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
130128
this.maxDistance = in.readOptionalFloat();
131129
this.minScore = in.readOptionalFloat();
132130
}
131+
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
132+
this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
133+
}
133134
}
134135

135136
@Override
@@ -152,6 +153,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
152153
out.writeOptionalFloat(this.maxDistance);
153154
out.writeOptionalFloat(this.minScore);
154155
}
156+
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
157+
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
158+
}
155159
}
156160

157161
@Override
@@ -174,6 +178,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
174178
if (Objects.nonNull(minScore)) {
175179
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
176180
}
181+
if (Objects.nonNull(methodParameters)) {
182+
MethodParametersParser.doXContent(xContentBuilder, methodParameters);
183+
}
177184
printBoostAndQueryName(xContentBuilder);
178185
xContentBuilder.endObject();
179186
xContentBuilder.endObject();
@@ -267,6 +274,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
267274
} else if (token == XContentParser.Token.START_OBJECT) {
268275
if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
269276
neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
277+
} else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
278+
neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(parser));
270279
}
271280
} else {
272281
throw new ParsingException(
@@ -292,15 +301,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
292301
if (vectorSupplier().get() == null) {
293302
return this;
294303
}
295-
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter());
296-
if (maxDistance != null) {
297-
knnQueryBuilder.maxDistance(maxDistance);
298-
} else if (minScore != null) {
299-
knnQueryBuilder.minScore(minScore);
300-
} else {
301-
knnQueryBuilder.k(k);
302-
}
303-
return knnQueryBuilder;
304+
return KNNQueryBuilder.builder()
305+
.fieldName(fieldName())
306+
.vector(vectorSupplier.get())
307+
.filter(filter())
308+
.maxDistance(maxDistance)
309+
.minScore(minScore)
310+
.k(k)
311+
.build();
304312
}
305313

306314
SetOnce<float[]> vectorSetOnce = new SetOnce<>();
@@ -326,7 +334,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
326334
maxDistance(),
327335
minScore(),
328336
vectorSetOnce::get,
329-
filter()
337+
filter(),
338+
methodParameters()
330339
);
331340
}
332341

@@ -359,14 +368,6 @@ public String getWriteableName() {
359368
return NAME;
360369
}
361370

362-
private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
363-
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
364-
}
365-
366-
private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
367-
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
368-
}
369-
370371
private static boolean validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) {
371372
int queryCount = 0;
372373
if (neuralQueryBuilder.k() != null) {

0 commit comments

Comments
 (0)