From 5485bea3f390c15348c3912efe9eacb21dd3c20b Mon Sep 17 00:00:00 2001 From: Chloe Gao Date: Wed, 5 Mar 2025 22:32:26 -0800 Subject: [PATCH] Add filter function to KNNQueryBuilder with unit tests and integration tests Signed-off-by: Chloe Gao --- CHANGELOG.md | 1 + .../knn/index/query/KNNQueryBuilder.java | 18 +++++++ .../org/opensearch/knn/index/FaissIT.java | 50 +++++++++++++++++++ .../knn/index/query/KNNQueryBuilderTests.java | 25 ++++++++++ 4 files changed, 94 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ca52ba8324..84c95399ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * [Remote Vector Index Build] Implement data download and IndexOutput write functionality [#2554](https://github.com/opensearch-project/k-NN/pull/2554) * [Remote Vector Index Build] Introduce Client Skeleton + basic Build Request implementation [#2560](https://github.com/opensearch-project/k-NN/pull/2560) * Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345] +* Add filter function to KNNQueryBuilder with unit tests and integration tests. [#2585](https://github.com/opensearch-project/k-NN/pull/2585) ### Enhancements * Introduce node level circuit breakers for k-NN [#2509](https://github.com/opensearch-project/k-NN/pull/2509) ### Bug Fixes diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index f032210aab..4bcc34348e 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -381,6 +381,24 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio KNNQueryBuilderParser.toXContent(builder, params, this); } + /** + * Add a filter to Neural Query Builder + * @param filterToBeAdded filter to be added + * @return return itself with underlying filter combined with passed in filter + */ + @Override + public QueryBuilder filter(QueryBuilder filterToBeAdded) { + if (validateFilterParams(filterToBeAdded) == false) { + return this; + } + if (filter == null) { + filter = filterToBeAdded; + return this; + } + filter = filter.filter(filterToBeAdded); + return this; + } + @Override protected Query doToQuery(QueryShardContext context) { MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index b40b36eb40..70fc01fb65 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -23,8 +23,10 @@ import org.opensearch.common.settings.Settings; import org.opensearch.client.ResponseException; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchNoneQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.common.xcontent.XContentFactory; @@ -451,6 +453,54 @@ public void testQueryWithFilterMultipleShards() { assertEquals(1, knnResults.size()); } + @SneakyThrows + public void testQueryWithFilterFunctionAppliedMultipleShards() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, "3") + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .startObject(INTEGER_FIELD_NAME) + .field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build()); + putMappingRequest(INDEX_NAME, mapping); + + addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01")); + + refreshIndex(INDEX_NAME); + + final float[] searchVector = { 6.0f, 7.0f, 3.0f }; + + // Add initial RangeQuery to a new KNNQueryBuilder + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01"); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, searchVector, 1, null); + knnQueryBuilder.filter(rangeQueryBuilder); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, 10); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(1, knnResults.size()); + + // Apply another MatchNoneQueryBuilder to new KNNQueryBuilder + knnQueryBuilder.filter(new MatchNoneQueryBuilder()); + response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, 10); + responseBody = EntityUtils.toString(response.getEntity()); + knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(0, knnResults.size()); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 1333d616e6..cb9f4a8645 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -22,6 +22,7 @@ import org.opensearch.core.index.Index; import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryRewriteContext; @@ -1077,4 +1078,28 @@ public void testDoRewrite_whenFilterSet_thenSuccessful() { // Then assertEquals(expected, actual); } + + @SneakyThrows + public void testFilter() { + // Test for Null Case + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + QueryBuilder updatedKnnQueryBuilder = knnQueryBuilder.filter(null); + assertEquals(knnQueryBuilder, updatedKnnQueryBuilder); + + // Test for valid case + knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(TERM_QUERY).k(K).build(); + updatedKnnQueryBuilder = knnQueryBuilder.filter(TERM_QUERY); + BoolQueryBuilder expectedUpdatedQueryFilter = new BoolQueryBuilder(); + expectedUpdatedQueryFilter.must(TERM_QUERY); + expectedUpdatedQueryFilter.filter(TERM_QUERY); + assertEquals(knnQueryBuilder, updatedKnnQueryBuilder); + assertEquals(expectedUpdatedQueryFilter, knnQueryBuilder.getFilter()); + + // Test for queryBuilder without filter initialized where filter function would + // simply assign filter to its filter field. + knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(K).build(); + updatedKnnQueryBuilder = knnQueryBuilder.filter(TERM_QUERY); + assertEquals(knnQueryBuilder, updatedKnnQueryBuilder); + assertEquals(TERM_QUERY, knnQueryBuilder.getFilter()); + } }