Skip to content

Commit

Permalink
Add filter function to KNNQueryBuilder with unit tests and integratio…
Browse files Browse the repository at this point in the history
…n tests

Signed-off-by: Chloe Gao <chloewq@amazon.com>
  • Loading branch information
chloewqg committed Mar 10, 2025
1 parent cc6b4f4 commit 5485bea
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* [Remote Vector Index Build] Implement data download and IndexOutput write functionality [#2554](https://github.com/opensearch-project/k-NN/pull/2554)
* [Remote Vector Index Build] Introduce Client Skeleton + basic Build Request implementation [#2560](https://github.com/opensearch-project/k-NN/pull/2560)
* Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345]
* Add filter function to KNNQueryBuilder with unit tests and integration tests. [#2585](https://github.com/opensearch-project/k-NN/pull/2585)
### Enhancements
* Introduce node level circuit breakers for k-NN [#2509](https://github.com/opensearch-project/k-NN/pull/2509)
### Bug Fixes
Expand Down
18 changes: 18 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,24 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
KNNQueryBuilderParser.toXContent(builder, params, this);
}

/**
* Add a filter to Neural Query Builder
* @param filterToBeAdded filter to be added
* @return return itself with underlying filter combined with passed in filter
*/
@Override
public QueryBuilder filter(QueryBuilder filterToBeAdded) {
if (validateFilterParams(filterToBeAdded) == false) {
return this;
}
if (filter == null) {
filter = filterToBeAdded;
return this;
}
filter = filter.filter(filterToBeAdded);
return this;
}

@Override
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);
Expand Down
50 changes: 50 additions & 0 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.MatchNoneQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.common.xcontent.XContentFactory;
Expand Down Expand Up @@ -451,6 +453,54 @@ public void testQueryWithFilterMultipleShards() {
assertEquals(1, knnResults.size());
}

@SneakyThrows
public void testQueryWithFilterFunctionAppliedMultipleShards() {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD_NAME)
.startObject(FIELD_NAME)
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, "3")
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, METHOD_HNSW)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.endObject()
.endObject()
.startObject(INTEGER_FIELD_NAME)
.field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER)
.endObject()
.endObject()
.endObject();
String mapping = builder.toString();
createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build());
putMappingRequest(INDEX_NAME, mapping);

addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01"));

refreshIndex(INDEX_NAME);

final float[] searchVector = { 6.0f, 7.0f, 3.0f };

// Add initial RangeQuery to a new KNNQueryBuilder
RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01");
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, searchVector, 1, null);
knnQueryBuilder.filter(rangeQueryBuilder);
Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, 10);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(1, knnResults.size());

// Apply another MatchNoneQueryBuilder to new KNNQueryBuilder
knnQueryBuilder.filter(new MatchNoneQueryBuilder());
response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, 10);
responseBody = EntityUtils.toString(response.getEntity());
knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(0, knnResults.size());
}

@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {
String indexName = "test-index";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.core.index.Index;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryRewriteContext;
Expand Down Expand Up @@ -1077,4 +1078,28 @@ public void testDoRewrite_whenFilterSet_thenSuccessful() {
// Then
assertEquals(expected, actual);
}

@SneakyThrows
public void testFilter() {
// Test for Null Case
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K);
QueryBuilder updatedKnnQueryBuilder = knnQueryBuilder.filter(null);
assertEquals(knnQueryBuilder, updatedKnnQueryBuilder);

// Test for valid case
knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(TERM_QUERY).k(K).build();
updatedKnnQueryBuilder = knnQueryBuilder.filter(TERM_QUERY);
BoolQueryBuilder expectedUpdatedQueryFilter = new BoolQueryBuilder();
expectedUpdatedQueryFilter.must(TERM_QUERY);
expectedUpdatedQueryFilter.filter(TERM_QUERY);
assertEquals(knnQueryBuilder, updatedKnnQueryBuilder);
assertEquals(expectedUpdatedQueryFilter, knnQueryBuilder.getFilter());

// Test for queryBuilder without filter initialized where filter function would
// simply assign filter to its filter field.
knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(K).build();
updatedKnnQueryBuilder = knnQueryBuilder.filter(TERM_QUERY);
assertEquals(knnQueryBuilder, updatedKnnQueryBuilder);
assertEquals(TERM_QUERY, knnQueryBuilder.getFilter());
}
}

0 comments on commit 5485bea

Please sign in to comment.