Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filter function to KNNQueryBuilder #2585

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* [Remote Vector Index Build] Implement data download and IndexOutput write functionality [#2554](https://github.com/opensearch-project/k-NN/pull/2554)
* [Remote Vector Index Build] Introduce Client Skeleton + basic Build Request implementation [#2560](https://github.com/opensearch-project/k-NN/pull/2560)
* Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345]
* Add filter function to KNNQueryBuilder with unit tests and integration tests. [#2585](https://github.com/opensearch-project/k-NN/pull/2585)
### Enhancements
* Introduce node level circuit breakers for k-NN [#2509](https://github.com/opensearch-project/k-NN/pull/2509)
### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,24 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
KNNQueryBuilderParser.toXContent(builder, params, this);
}

/**
* Add a filter to Neural Query Builder
* @param filterToBeAdded filter to be added
* @return return itself with underlying filter combined with passed in filter
*/
@Override
public QueryBuilder filter(QueryBuilder filterToBeAdded) {
if (validateFilterParams(filterToBeAdded) == false) {
return this;
}
if (filter == null) {
filter = filterToBeAdded;
return this;
}
filter = filter.filter(filterToBeAdded);
return this;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refer #2585 (comment) and go through the builder here. Do not manipulate the existing filter, it increases the chances of side effects and bugs

return KNNQueryBuilder.builder().filter(newFilter)....build()//copy everything from existing

}

@Override
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);
Expand Down
50 changes: 50 additions & 0 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.MatchNoneQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.common.xcontent.XContentFactory;
Expand Down Expand Up @@ -451,6 +453,54 @@ public void testQueryWithFilterMultipleShards() {
assertEquals(1, knnResults.size());
}

@SneakyThrows
public void testQueryWithFilterFunctionAppliedMultipleShards() {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD_NAME)
.startObject(FIELD_NAME)
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, "3")
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, METHOD_HNSW)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.endObject()
.endObject()
.startObject(INTEGER_FIELD_NAME)
.field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER)
.endObject()
.endObject()
.endObject();
String mapping = builder.toString();
createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build());
putMappingRequest(INDEX_NAME, mapping);

addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01"));

refreshIndex(INDEX_NAME);

final float[] searchVector = { 6.0f, 7.0f, 3.0f };

// Add initial RangeQuery to a new KNNQueryBuilder
RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01");
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, searchVector, 1, null);
knnQueryBuilder.filter(rangeQueryBuilder);
Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, 10);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(1, knnResults.size());

// Apply another MatchNoneQueryBuilder to new KNNQueryBuilder
knnQueryBuilder.filter(new MatchNoneQueryBuilder());
response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, 10);
responseBody = EntityUtils.toString(response.getEntity());
knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(0, knnResults.size());
}

@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {
String indexName = "test-index";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.core.index.Index;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryRewriteContext;
Expand Down Expand Up @@ -1077,4 +1078,28 @@ public void testDoRewrite_whenFilterSet_thenSuccessful() {
// Then
assertEquals(expected, actual);
}

@SneakyThrows
public void testFilter() {
// Test for Null Case
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K);
QueryBuilder updatedKnnQueryBuilder = knnQueryBuilder.filter(null);
assertEquals(knnQueryBuilder, updatedKnnQueryBuilder);

// Test for valid case
knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(TERM_QUERY).k(K).build();
updatedKnnQueryBuilder = knnQueryBuilder.filter(TERM_QUERY);
BoolQueryBuilder expectedUpdatedQueryFilter = new BoolQueryBuilder();
expectedUpdatedQueryFilter.must(TERM_QUERY);
expectedUpdatedQueryFilter.filter(TERM_QUERY);
assertEquals(knnQueryBuilder, updatedKnnQueryBuilder);
assertEquals(expectedUpdatedQueryFilter, knnQueryBuilder.getFilter());

// Test for queryBuilder without filter initialized where filter function would
// simply assign filter to its filter field.
knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(K).build();
updatedKnnQueryBuilder = knnQueryBuilder.filter(TERM_QUERY);
assertEquals(knnQueryBuilder, updatedKnnQueryBuilder);
assertEquals(TERM_QUERY, knnQueryBuilder.getFilter());
}
}
Loading