Skip to content

Commit 935829a

Browse files
shatejasvibrantvarun
authored andcommitted
Adds method_parameters in neural search query to support ef_search (opensearch-project#787) (opensearch-project#814)
Signed-off-by: Tejas Shah <shatejas@amazon.com>
1 parent ded2788 commit 935829a

File tree

14 files changed

+165
-40
lines changed

14 files changed

+165
-40
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1515
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.15...2.x)
1616
### Features
1717
### Enhancements
18+
* Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
1819
### Bug Fixes
1920
### Infrastructure
2021
### Documentation

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.Arrays;
1111
import java.util.List;
1212
import java.util.Map;
13+
1314
import org.opensearch.index.query.MatchQueryBuilder;
1415
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
1516
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
@@ -69,6 +70,7 @@ private void validateNormalizationProcessor(final String fileName, final String
6970
loadModel(modelId);
7071
addDocuments(getIndexNameForTest(), false);
7172
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName);
73+
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName, Map.of("ef_search", 100));
7274
} finally {
7375
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
7476
}
@@ -96,10 +98,14 @@ private void createSearchPipeline(final String pipelineName) {
9698
);
9799
}
98100

99-
private void validateTestIndex(final String modelId, final String index, final String searchPipeline) throws Exception {
101+
private void validateTestIndex(final String modelId, final String index, final String searchPipeline) {
102+
validateTestIndex(modelId, index, searchPipeline, null);
103+
}
104+
105+
private void validateTestIndex(final String modelId, final String index, final String searchPipeline, Map<String, ?> methodParameters) {
100106
int docCount = getDocCount(index);
101107
assertEquals(6, docCount);
102-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId);
108+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
103109
Map<String, Object> searchResponseAsMap = search(index, hybridQueryBuilder, null, 1, Map.of("search_pipeline", searchPipeline));
104110
assertNotNull(searchResponseAsMap);
105111
int hits = getHitCount(searchResponseAsMap);
@@ -110,12 +116,15 @@ private void validateTestIndex(final String modelId, final String index, final S
110116
}
111117
}
112118

113-
private HybridQueryBuilder getQueryBuilder(final String modelId) {
119+
private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters) {
114120
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
115121
neuralQueryBuilder.fieldName("passage_embedding");
116122
neuralQueryBuilder.modelId(modelId);
117123
neuralQueryBuilder.queryText(QUERY);
118124
neuralQueryBuilder.k(5);
125+
if (methodParameters != null) {
126+
neuralQueryBuilder.methodParameters(methodParameters);
127+
}
119128

120129
MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);
121130

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ private void validateIndexQuery(final String modelId) {
6060
null,
6161
0.01f,
6262
null,
63+
null,
6364
null
6465
);
6566
Map<String, Object> responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
@@ -74,6 +75,7 @@ private void validateIndexQuery(final String modelId) {
7475
100000f,
7576
null,
7677
null,
78+
null,
7779
null
7880
);
7981
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ private void validateTestIndex(final String modelId) throws Exception {
6262
null,
6363
null,
6464
null,
65+
null,
6566
null
6667
);
6768
Map<String, Object> response = search(getIndexNameForTest(), neuralQueryBuilder, 1);

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java

+11-2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
7373
loadModel(modelId);
7474
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
7575
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId);
76+
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, Map.of("ef_search", 100));
7677
} finally {
7778
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
7879
}
@@ -83,10 +84,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
8384
}
8485

8586
private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId) throws Exception {
87+
validateTestIndexOnUpgrade(numberOfDocs, modelId, null);
88+
}
89+
90+
private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, Map<String, ?> methodParameters)
91+
throws Exception {
8692
int docCount = getDocCount(getIndexNameForTest());
8793
assertEquals(numberOfDocs, docCount);
8894
loadModel(modelId);
89-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId);
95+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
9096
Map<String, Object> searchResponseAsMap = search(
9197
getIndexNameForTest(),
9298
hybridQueryBuilder,
@@ -103,12 +109,15 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
103109
}
104110
}
105111

106-
private HybridQueryBuilder getQueryBuilder(final String modelId) {
112+
private HybridQueryBuilder getQueryBuilder(final String modelId, final Map<String, ?> methodParameters) {
107113
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
108114
neuralQueryBuilder.fieldName("passage_embedding");
109115
neuralQueryBuilder.modelId(modelId);
110116
neuralQueryBuilder.queryText(QUERY);
111117
neuralQueryBuilder.k(5);
118+
if (methodParameters != null) {
119+
neuralQueryBuilder.methodParameters(methodParameters);
120+
}
112121

113122
MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);
114123

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java

+2
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
8686
null,
8787
0.01f,
8888
null,
89+
null,
8990
null
9091
);
9192
Map<String, Object> responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
@@ -100,6 +101,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
100101
100000f,
101102
null,
102103
null,
104+
null,
103105
null
104106
);
105107
Map<String, Object> responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java

+1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
8585
null,
8686
null,
8787
null,
88+
null,
8889
null
8990
);
9091
Map<String, Object> responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.common;
6+
7+
import com.google.common.collect.ImmutableMap;
8+
import org.opensearch.Version;
9+
import org.opensearch.knn.index.IndexUtil;
10+
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
11+
12+
import java.util.Map;
13+
14+
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
15+
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
16+
import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD;
17+
18+
/**
19+
* A util class which holds the logic to determine the min version supported by the request parameters
20+
*/
21+
public final class MinClusterVersionUtil {
22+
23+
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
24+
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
25+
26+
// Note this minimal version will act as a override
27+
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
28+
.put(MODEL_ID_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID)
29+
.put(MAX_DISTANCE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
30+
.put(MIN_SCORE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
31+
.build();
32+
33+
public static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
34+
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
35+
}
36+
37+
public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
38+
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
39+
}
40+
41+
public static boolean isClusterOnOrAfterMinReqVersion(String key) {
42+
Version version;
43+
if (MINIMAL_VERSION_NEURAL.containsKey(key)) {
44+
version = MINIMAL_VERSION_NEURAL.get(key);
45+
} else {
46+
version = IndexUtil.minimalRequiredVersionMap.get(key);
47+
}
48+
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(version);
49+
}
50+
}

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

+23-21
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
package org.opensearch.neuralsearch.query;
66

77
import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
8+
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
9+
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
10+
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
11+
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
12+
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport;
13+
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch;
814
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;
915
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE;
1016
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT;
@@ -19,7 +25,6 @@
1925
import org.apache.commons.lang.builder.EqualsBuilder;
2026
import org.apache.commons.lang.builder.HashCodeBuilder;
2127
import org.apache.lucene.search.Query;
22-
import org.opensearch.Version;
2328
import org.opensearch.common.SetOnce;
2429
import org.opensearch.core.ParseField;
2530
import org.opensearch.core.action.ActionListener;
@@ -34,8 +39,9 @@
3439
import org.opensearch.index.query.QueryRewriteContext;
3540
import org.opensearch.index.query.QueryShardContext;
3641
import org.opensearch.knn.index.query.KNNQueryBuilder;
42+
import org.opensearch.knn.index.query.parser.MethodParametersParser;
43+
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
3744
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
38-
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
3945

4046
import com.google.common.annotations.VisibleForTesting;
4147

@@ -69,18 +75,11 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
6975
@VisibleForTesting
7076
static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image");
7177

72-
@VisibleForTesting
73-
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
78+
public static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
7479

7580
@VisibleForTesting
7681
static final ParseField K_FIELD = new ParseField("k");
7782

78-
@VisibleForTesting
79-
static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");
80-
81-
@VisibleForTesting
82-
static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");
83-
8483
private static final int DEFAULT_K = 10;
8584

8685
private static MLCommonsClientAccessor ML_CLIENT;
@@ -101,8 +100,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
101100
@Setter(AccessLevel.PACKAGE)
102101
private Supplier<float[]> vectorSupplier;
103102
private QueryBuilder filter;
104-
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
105-
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
103+
private Map<String, ?> methodParameters;
106104

107105
/**
108106
* Constructor from stream input
@@ -130,6 +128,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
130128
this.maxDistance = in.readOptionalFloat();
131129
this.minScore = in.readOptionalFloat();
132130
}
131+
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
132+
this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
133+
}
133134
}
134135

135136
@Override
@@ -152,6 +153,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
152153
out.writeOptionalFloat(this.maxDistance);
153154
out.writeOptionalFloat(this.minScore);
154155
}
156+
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
157+
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
158+
}
155159
}
156160

157161
@Override
@@ -174,6 +178,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
174178
if (Objects.nonNull(minScore)) {
175179
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
176180
}
181+
if (Objects.nonNull(methodParameters)) {
182+
MethodParametersParser.doXContent(xContentBuilder, methodParameters);
183+
}
177184
printBoostAndQueryName(xContentBuilder);
178185
xContentBuilder.endObject();
179186
xContentBuilder.endObject();
@@ -267,6 +274,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
267274
} else if (token == XContentParser.Token.START_OBJECT) {
268275
if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
269276
neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
277+
} else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
278+
neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(parser));
270279
}
271280
} else {
272281
throw new ParsingException(
@@ -325,7 +334,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
325334
maxDistance(),
326335
minScore(),
327336
vectorSetOnce::get,
328-
filter()
337+
filter(),
338+
methodParameters()
329339
);
330340
}
331341

@@ -358,14 +368,6 @@ public String getWriteableName() {
358368
return NAME;
359369
}
360370

361-
private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
362-
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
363-
}
364-
365-
private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
366-
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
367-
}
368-
369371
private static boolean validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) {
370372
int queryCount = 0;
371373
if (neuralQueryBuilder.k() != null) {

src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java

+3
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() {
9696
null,
9797
null,
9898
null,
99+
null,
99100
null
100101
);
101102
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
@@ -146,6 +147,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu
146147
null,
147148
null,
148149
null,
150+
null,
149151
null
150152
);
151153
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
@@ -185,6 +187,7 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() {
185187
null,
186188
null,
187189
null,
190+
null,
188191
null
189192
);
190193
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);

src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
224224

225225
HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
226226
hybridQueryBuilderDefaultNorm.add(
227-
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
227+
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
228228
);
229229
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
230230

@@ -249,7 +249,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
249249

250250
HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
251251
hybridQueryBuilderL2Norm.add(
252-
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
252+
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
253253
);
254254
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
255255

@@ -299,7 +299,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
299299

300300
HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
301301
hybridQueryBuilderDefaultNorm.add(
302-
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
302+
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
303303
);
304304
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
305305

@@ -324,7 +324,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
324324

325325
HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
326326
hybridQueryBuilderL2Norm.add(
327-
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
327+
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
328328
);
329329
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
330330

0 commit comments

Comments
 (0)