Skip to content

Commit 15bc8d0

Browse files
committed
Add support for the new k-NN query parameter `expand_nested`.
Signed-off-by: Bo Zhang <bzhangam@amazon.com>
1 parent 393d49a commit 15bc8d0

File tree

14 files changed: 204 additions (+204) and 17 deletions (-17).

CHANGELOG.md

+1
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1818
### Features
1919
### Enhancements
2020
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
21+
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
2122
### Bug Fixes
2223
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
2324
### Infrastructure

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java

+11-3
Original file line number | Diff line number | Diff line change
@@ -71,9 +71,9 @@ private void validateNormalizationProcessor(final String fileName, final String
7171
modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR);
7272
loadModel(modelId);
7373
addDocuments(getIndexNameForTest(), false);
74-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
74+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
7575
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
76-
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
76+
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.TRUE, Map.of("ef_search", 100), RescoreContext.getDefault());
7777
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
7878
} finally {
7979
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
@@ -115,12 +115,20 @@ private void validateTestIndex(final String index, final String searchPipeline,
115115
}
116116
}
117117

118-
private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters, RescoreContext rescoreContext) {
118+
private HybridQueryBuilder getQueryBuilder(
119+
final String modelId,
120+
final Boolean expandNestedDocs,
121+
final Map<String, ?> methodParameters,
122+
final RescoreContext rescoreContext
123+
) {
119124
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
120125
neuralQueryBuilder.fieldName("passage_embedding");
121126
neuralQueryBuilder.modelId(modelId);
122127
neuralQueryBuilder.queryText(QUERY);
123128
neuralQueryBuilder.k(5);
129+
if (expandNestedDocs != null) {
130+
neuralQueryBuilder.expandNested(expandNestedDocs);
131+
}
124132
if (methodParameters != null) {
125133
neuralQueryBuilder.methodParameters(methodParameters);
126134
}

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java

+2
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@ private void validateIndexQuery(final String modelId) {
6262
null,
6363
null,
6464
null,
65+
null,
6566
null
6667
);
6768
Map<String, Object> responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
@@ -78,6 +79,7 @@ private void validateIndexQuery(final String modelId) {
7879
null,
7980
null,
8081
null,
82+
null,
8183
null
8284
);
8385
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java

+1
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,7 @@ private void validateTestIndex(final String modelId) throws Exception {
6464
null,
6565
null,
6666
null,
67+
null,
6768
null
6869
);
6970
Map<String, Object> response = search(getIndexNameForTest(), neuralQueryBuilder, 1);

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java

+8-4
Original file line number | Diff line number | Diff line change
@@ -67,13 +67,13 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
6767
int totalDocsCountMixed;
6868
if (isFirstMixedRound()) {
6969
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
70-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
70+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
7171
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
7272
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer);
7373
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
7474
} else {
7575
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
76-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
76+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
7777
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
7878
}
7979
break;
@@ -83,10 +83,10 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
8383
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
8484
loadModel(modelId);
8585
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
86-
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
86+
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
8787
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
8888
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
89-
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
89+
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.TRUE, Map.of("ef_search", 100), RescoreContext.getDefault());
9090
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
9191
} finally {
9292
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
@@ -124,6 +124,7 @@ private void validateTestIndexOnUpgrade(
124124

125125
private HybridQueryBuilder getQueryBuilder(
126126
final String modelId,
127+
final Boolean expandNestedDocs,
127128
final Map<String, ?> methodParameters,
128129
final RescoreContext rescoreContextForNeuralQuery
129130
) {
@@ -132,6 +133,9 @@ private HybridQueryBuilder getQueryBuilder(
132133
neuralQueryBuilder.modelId(modelId);
133134
neuralQueryBuilder.queryText(QUERY);
134135
neuralQueryBuilder.k(5);
136+
if (expandNestedDocs != null) {
137+
neuralQueryBuilder.expandNested(expandNestedDocs);
138+
}
135139
if (methodParameters != null) {
136140
neuralQueryBuilder.methodParameters(methodParameters);
137141
}

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java

+2
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
8888
null,
8989
null,
9090
null,
91+
null,
9192
null
9293
);
9394
Map<String, Object> responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
@@ -104,6 +105,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
104105
null,
105106
null,
106107
null,
108+
null,
107109
null
108110
);
109111
Map<String, Object> responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java

+1
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
8787
null,
8888
null,
8989
null,
90+
null,
9091
null
9192
);
9293
Map<String, Object> responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1);

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

+15
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
*/
55
package org.opensearch.neuralsearch.query;
66

7+
import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
78
import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
89
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
910
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
@@ -98,6 +99,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
9899
private Integer k = null;
99100
private Float maxDistance = null;
100101
private Float minScore = null;
102+
private Boolean expandNested;
101103
@VisibleForTesting
102104
@Getter(AccessLevel.PACKAGE)
103105
@Setter(AccessLevel.PACKAGE)
@@ -132,6 +134,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
132134
this.maxDistance = in.readOptionalFloat();
133135
this.minScore = in.readOptionalFloat();
134136
}
137+
if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) {
138+
this.expandNested = in.readOptionalBoolean();
139+
}
135140
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
136141
this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
137142
}
@@ -158,6 +163,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
158163
out.writeOptionalFloat(this.maxDistance);
159164
out.writeOptionalFloat(this.minScore);
160165
}
166+
if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) {
167+
out.writeOptionalBoolean(this.expandNested);
168+
}
161169
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
162170
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
163171
}
@@ -184,6 +192,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
184192
if (Objects.nonNull(minScore)) {
185193
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
186194
}
195+
if (Objects.nonNull(expandNested)) {
196+
xContentBuilder.field(EXPAND_NESTED_FIELD.getPreferredName(), expandNested);
197+
}
187198
if (Objects.nonNull(methodParameters)) {
188199
MethodParametersParser.doXContent(xContentBuilder, methodParameters);
189200
}
@@ -274,6 +285,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
274285
neuralQueryBuilder.maxDistance(parser.floatValue());
275286
} else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
276287
neuralQueryBuilder.minScore(parser.floatValue());
288+
} else if (EXPAND_NESTED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
289+
neuralQueryBuilder.expandNested(parser.booleanValue());
277290
} else {
278291
throw new ParsingException(
279292
parser.getTokenLocation(),
@@ -318,6 +331,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
318331
.filter(filter())
319332
.maxDistance(maxDistance)
320333
.minScore(minScore)
334+
.expandNested(expandNested)
321335
.k(k)
322336
.methodParameters(methodParameters)
323337
.rescoreContext(rescoreContext)
@@ -346,6 +360,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
346360
k(),
347361
maxDistance(),
348362
minScore(),
363+
expandNested(),
349364
vectorSetOnce::get,
350365
filter(),
351366
methodParameters(),

src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java

+3
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,7 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() {
9898
null,
9999
null,
100100
null,
101+
null,
101102
null
102103
);
103104
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
@@ -150,6 +151,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu
150151
null,
151152
null,
152153
null,
154+
null,
153155
null
154156
);
155157
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
@@ -191,6 +193,7 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() {
191193
null,
192194
null,
193195
null,
196+
null,
194197
null
195198
);
196199
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);

src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java

+56-4
Original file line number | Diff line number | Diff line change
@@ -224,7 +224,20 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
224224

225225
HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
226226
hybridQueryBuilderDefaultNorm.add(
227-
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
227+
new NeuralQueryBuilder(
228+
TEST_KNN_VECTOR_FIELD_NAME_1,
229+
TEST_DOC_TEXT1,
230+
"",
231+
modelId,
232+
5,
233+
null,
234+
null,
235+
null,
236+
null,
237+
null,
238+
null,
239+
null
240+
)
228241
);
229242
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
230243

@@ -249,7 +262,20 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
249262

250263
HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
251264
hybridQueryBuilderL2Norm.add(
252-
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
265+
new NeuralQueryBuilder(
266+
TEST_KNN_VECTOR_FIELD_NAME_1,
267+
TEST_DOC_TEXT1,
268+
"",
269+
modelId,
270+
5,
271+
null,
272+
null,
273+
null,
274+
null,
275+
null,
276+
null,
277+
null
278+
)
253279
);
254280
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
255281

@@ -299,7 +325,20 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
299325

300326
HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
301327
hybridQueryBuilderDefaultNorm.add(
302-
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
328+
new NeuralQueryBuilder(
329+
TEST_KNN_VECTOR_FIELD_NAME_1,
330+
TEST_DOC_TEXT1,
331+
"",
332+
modelId,
333+
5,
334+
null,
335+
null,
336+
null,
337+
null,
338+
null,
339+
null,
340+
null
341+
)
303342
);
304343
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
305344

@@ -324,7 +363,20 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
324363

325364
HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
326365
hybridQueryBuilderL2Norm.add(
327-
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
366+
new NeuralQueryBuilder(
367+
TEST_KNN_VECTOR_FIELD_NAME_1,
368+
TEST_DOC_TEXT1,
369+
"",
370+
modelId,
371+
5,
372+
null,
373+
null,
374+
null,
375+
null,
376+
null,
377+
null,
378+
null
379+
)
328380
);
329381
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
330382

0 commit comments

Comments
 (0)