diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 25baa43fde..831078da2a 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -318,7 +318,6 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter if (methodParamsJ != nullptr) { methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ); } - // The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from // the query point std::vector dis(kJ); @@ -357,7 +356,10 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter } else { auto ivfReader = dynamic_cast(indexReader->index); auto ivfFlatReader = dynamic_cast(indexReader->index); + if(ivfReader || ivfFlatReader) { + int indexNprobe = ivfReader == nullptr ? ivfReader->nprobe : ivfFlatReader->nprobe; + ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe); ivfParams.sel = idSelector.get(); searchParameters = &ivfParams; } @@ -373,10 +375,11 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter } else { faiss::SearchParameters *searchParameters = nullptr; faiss::SearchParametersHNSW hnswParams; + faiss::SearchParametersIVF ivfParams; std::unique_ptr idGrouper; std::vector idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); - if(hnswReader!= nullptr) { + if(hnswReader != nullptr) { // Query param efsearch supersedes ef_search provided during index setting. hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); if (parentIdsJ != nullptr) { @@ -384,6 +387,13 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter hnswParams.grp = idGrouper.get(); } searchParameters = &hnswParams; + } else { + auto ivfReader = dynamic_cast(indexReader->index); + if (ivfReader) { + int indexNprobe = ivfReader->nprobe; + ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe); + searchParameters = &ivfParams; + } } try { indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters); diff --git a/jni/tests/faiss_wrapper_unit_test.cpp b/jni/tests/faiss_wrapper_unit_test.cpp index d9fdac23fb..3911346edf 100644 --- a/jni/tests/faiss_wrapper_unit_test.cpp +++ b/jni/tests/faiss_wrapper_unit_test.cpp @@ -12,6 +12,8 @@ #include "faiss_wrapper.h" #include +#include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -30,6 +32,46 @@ struct MockIndex : faiss::IndexHNSW { } }; +struct MockIVFIndex : faiss::IndexIVFFlat { + explicit MockIVFIndex() = default; +}; + +struct MockIVFIdMap : faiss::IndexIDMap { + mutable idx_t nCalled{}; + mutable const float *xCalled{}; + mutable idx_t kCalled{}; + mutable float *distancesCalled{}; + mutable idx_t *labelsCalled{}; + mutable const faiss::SearchParametersIVF *paramsCalled{}; + + explicit MockIVFIdMap(MockIVFIndex *index) : faiss::IndexIDMapTemplate(index) { + } + + void search( + idx_t n, + const float *x, + idx_t k, + float *distances, + idx_t *labels, + const faiss::SearchParameters *params) const override { + nCalled = n; + xCalled = x; + kCalled = k; + distancesCalled = distances; + labelsCalled = labels; + paramsCalled = dynamic_cast(params); + } + + void resetMock() const { + nCalled = 0; + xCalled = nullptr; + kCalled = 0; + distancesCalled = nullptr; + labelsCalled = nullptr; + paramsCalled = nullptr; + } +}; + struct MockIdMap : faiss::IndexIDMap { mutable idx_t nCalled{}; mutable const float *xCalled{}; @@ -83,13 +125,14 @@ struct MockIdMap : faiss::IndexIDMap { } }; -struct QueryIndexHNSWTestInput { - std::string description; +struct QueryIndexInput { + string description; int k; - int efSearch; int filterIdType; bool filterIdsPresent; bool parentIdsPresent; + int efSearch; + int nprobe; }; struct RangeSearchTestInput { @@ -101,9 +144,9 @@ struct RangeSearchTestInput { bool parentIdsPresent; }; -class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam { +class FaissWrapperParametrizedTestFixture : public testing::TestWithParam { public: - FaissWrappeterParametrizedTestFixture() : index_(3), id_map_(&index_) { + FaissWrapperParametrizedTestFixture() : index_(3), id_map_(&index_) { index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere } @@ -123,16 +166,25 @@ class FaissWrapperParametrizedRangeSearchTestFixture : public testing::TestWithP MockIdMap id_map_; }; -namespace query_index_test { +class FaissWrapperIVFQueryTestFixture : public testing::TestWithParam { +public: + FaissWrapperIVFQueryTestFixture() : ivf_id_map_(&ivf_index_) { + ivf_index_.nprobe = 100; + }; - std::unordered_map methodParams; +protected: + MockIVFIndex ivf_index_; + MockIVFIdMap ivf_id_map_; +}; - TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexHNSWTests) { - // Given +namespace query_index_test { + + TEST_P(FaissWrapperParameterizedTestFixture, QueryIndexHNSWTests) { + //Given JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; - QueryIndexHNSWTestInput const &input = GetParam(); + QueryIndexInput const &input = GetParam(); float query[] = {1.2, 2.3, 3.4}; int efSearch = input.efSearch; @@ -184,24 +236,23 @@ namespace query_index_test { INSTANTIATE_TEST_CASE_P( QueryIndexHNSWTests, - FaissWrappeterParametrizedTestFixture, + FaissWrapperParametrizedTestFixture, ::testing::Values( - QueryIndexHNSWTestInput{"algoParams present, parent absent", 10, 200, 0, false, false}, - QueryIndexHNSWTestInput{"algoParams absent, parent absent", 10, -1, 0, false, false}, - QueryIndexHNSWTestInput{"algoParams present, parent present", 10, 200, 0, false, true}, - QueryIndexHNSWTestInput{"algoParams absent, parent present", 10, -1, 0, false, true} + QueryIndexInput {"algoParams present, parent absent", 10, 0, false, false, 200, -1 }, + QueryIndexInput {"algoParams present, parent absent", 10, 0, false, false, -1, -1 }, + QueryIndexInput {"algoParams present, parent present", 10, 0, false, true, 200, -1 }, + QueryIndexInput {"algoParams absent, parent present", 10, 0, false, true, -1, -1 } ) ); } namespace query_index_with_filter_test { - - TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexWithFilterHNSWTests) { - // Given + TEST_P(FaissWrapperParameterizedTestFixture, QueryIndexWithFilterHNSWTests) { + //Given JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; - QueryIndexHNSWTestInput const &input = GetParam(); + QueryIndexInput const &input = GetParam(); float query[] = {1.2, 2.3, 3.4}; std::vector *parentIdPtr = nullptr; @@ -267,23 +318,23 @@ namespace query_index_with_filter_test { INSTANTIATE_TEST_CASE_P( QueryIndexWithFilterHNSWTests, - FaissWrappeterParametrizedTestFixture, + FaissWrapperParametrizedTestFixture, ::testing::Values( - QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent", 10, 200, 0, false, false}, - QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10, 200, 1, false, false}, - QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present", 10, -1, 0, true, false}, - QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10, -1, 1, true, false}, - QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent", 10, 200, 0, false, true}, - QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent, filter type 1", 10, 150, 1, false, true}, - QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present", 10, -1, 0, true, true}, - QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present, filter type 1",10, -1, 1, true, true} + QueryIndexInput { "algoParams present, parent absent, filter absent", 10, 0, false, false, 200, -1 }, + QueryIndexInput { "algoParams present, parent absent, filter absent, filter type 1", 10, 1, false, false, 200, -1}, + QueryIndexInput { "algoParams absent, parent absent, filter present", 10, 0, true, false, -1, -1}, + QueryIndexInput { "algoParams absent, parent absent, filter present, filter type 1", 10, 1, true, false, -1, -1}, + QueryIndexInput { "algoParams present, parent present, filter absent", 10, 0, false, true, 200, -1 }, + QueryIndexInput { "algoParams present, parent present, filter absent, filter type 1", 10, 1, false, true, 150, -1}, + QueryIndexInput { "algoParams absent, parent present, filter present", 10, 0, true, true, -1, -1}, + QueryIndexInput { "algoParams absent, parent present, filter present, filter type 1",10, 1, true, true, -1, -1 } ) ); } namespace range_search_test { - TEST_P(FaissWrapperParametrizedRangeSearchTestFixture, RangeSearchHNSWTests) { + TEST_P(FaissWrapperParameterizedRangeSearchTestFixture, RangeSearchHNSWTests) { // Given JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; @@ -323,6 +374,7 @@ namespace range_search_test { std::vector filter; std::vector *filterptr = nullptr; if (input.filterIdsPresent) { + std::vector filter; filter.reserve(2); filter.push_back(1); filter.push_back(2); @@ -356,7 +408,7 @@ namespace range_search_test { INSTANTIATE_TEST_CASE_P( RangeSearchHNSWTests, - FaissWrapperParametrizedRangeSearchTestFixture, + FaissWrapperParameterizedRangeSearchTestFixture, ::testing::Values( RangeSearchTestInput{"algoParams present, parent absent, filter absent", 10.0f, 200, 0, false, false}, RangeSearchTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10.0f, 200, 1, false, false}, @@ -370,3 +422,65 @@ namespace range_search_test { ); } +namespace query_index_with_filter_test_ivf { + + TEST_P(FaissWrapperIVFQueryTestFixture, QueryIndexIVFTest) { + //Given + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + QueryIndexInput const &input = GetParam(); + float query[] = {1.2, 2.3, 3.4}; + + int nprobe = input.nprobe; + int expectedNprobe = 100; //default set in mock + std::unordered_map methodParams; + if (nprobe != -1) { + expectedNprobe = input.nprobe; + methodParams[knn_jni::NPROBES] = reinterpret_cast(&nprobe); + } + + std::vector *filterptr = nullptr; + if (input.filterIdsPresent) { + std::vector filter; + std::vector *filterptr = nullptr; + if (input.filterIdsPresent) { + std::vector filter; + filter.reserve(2); + filter.push_back(1); + filter.push_back(2); + filterptr = &filter; + } + } + // When + knn_jni::faiss_wrapper::QueryIndex_WithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&ivf_id_map_), + reinterpret_cast(&query), input.k, reinterpret_cast(&methodParams), + reinterpret_cast(filterptr), + input.filterIdType, + nullptr); + + //Then + int actualEfSearch = ivf_id_map_.paramsCalled->nprobe; + // Asserting the captured argument + EXPECT_EQ(input.k, ivf_id_map_.kCalled); + EXPECT_EQ(expectedNprobe, actualEfSearch); + if (input.parentIdsPresent) { + faiss::IDGrouper *grouper = ivf_id_map_.paramsCalled->grp; + EXPECT_TRUE(grouper != nullptr); + } + ivf_id_map_.resetMock(); + } + + INSTANTIATE_TEST_CASE_P( + QueryIndexIVFTest, + FaissWrapperIVFQueryTestFixture, + ::testing::Values( + QueryIndexInput{"algoParams present, parent absent", 10, 0, false, false, -1, 200 }, + QueryIndexInput{"algoParams present, parent absent", 10,0, false, false, -1, -1 }, + QueryIndexInput{"algoParams present, parent present", 10, 0, true, true, -1, 200 }, + QueryIndexInput{"algoParams absent, parent present", 10, 0, true, true, -1, -1 } + ) + ); +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 8b511e58e7..4daa4c2ab7 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -27,7 +27,6 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; -import org.opensearch.knn.index.util.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -35,6 +34,7 @@ import org.opensearch.knn.index.VectorQueryType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.parser.MethodParametersParser; +import org.opensearch.knn.index.util.EngineSpecificMethodContext; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -49,6 +49,7 @@ import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; @@ -72,6 +73,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE); public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE); public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH); + public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES); public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER); public static final int K_MAX = 10000; /** diff --git a/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java b/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java index e2ba8f26e9..41b69f4414 100644 --- a/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java +++ b/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java @@ -32,6 +32,9 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME; +/** + * Note: This parser is used by neural plugin as well, breaking changes will require changes in neural as well + */ @EqualsAndHashCode @Getter @AllArgsConstructor diff --git a/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java b/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java index 5c3aa51822..2039e810ce 100644 --- a/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java +++ b/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java @@ -21,7 +21,9 @@ import java.util.Map; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.index.query.KNNQueryBuilder.EF_SEARCH_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.NPROBE_FIELD; /** * MethodParameters are engine and algorithm related parameters that clients can pass in knn query @@ -35,11 +37,7 @@ public enum MethodParameter { EF_SEARCH(METHOD_PARAMETER_EF_SEARCH, Version.CURRENT, EF_SEARCH_FIELD) { @Override public Integer parse(Object value) { - try { - return Integer.parseInt(String.valueOf(value)); - } catch (final NumberFormatException e) { - throw new IllegalArgumentException(METHOD_PARAMETER_EF_SEARCH + " value must be an integer"); - } + return parseInteger(value, METHOD_PARAMETER_EF_SEARCH); } @Override @@ -48,11 +46,31 @@ public ValidationException validate(Object value) { if (ef != null && ef > 0) { return null; } - ; + ValidationException validationException = new ValidationException(); validationException.addValidationError(METHOD_PARAMETER_EF_SEARCH + " should be greater than 0"); return validationException; } + }, + + // TODO: change the version to 2.16 when merging into 2.x + NPROBE(METHOD_PARAMETER_NPROBES, Version.CURRENT, NPROBE_FIELD) { + @Override + public Integer parse(Object value) { + return parseInteger(value, METHOD_PARAMETER_EF_SEARCH); + } + + @Override + public ValidationException validate(Object value) { + final Integer nprobe = parse(value); + if (nprobe != null && nprobe > 0) { + return null; + } + + ValidationException validationException = new ValidationException(); + validationException.addValidationError(METHOD_PARAMETER_NPROBES + " should be greater than 0"); + return validationException; + } }; private final String name; @@ -75,4 +93,12 @@ public static MethodParameter enumOf(final String name) { } return PARAMETERS_DIR.get(name); } + + private static Integer parseInteger(Object value, String name) { + try { + return Integer.parseInt(String.valueOf(value)); + } catch (final NumberFormatException e) { + throw new IllegalArgumentException(name + " value must be an integer"); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/util/DefaultIVFContext.java b/src/main/java/org/opensearch/knn/index/util/DefaultIVFContext.java new file mode 100644 index 0000000000..2179bba460 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/util/DefaultIVFContext.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.util; + +import com.google.common.collect.ImmutableMap; +import org.opensearch.knn.index.Parameter; +import org.opensearch.knn.index.query.request.MethodParameter; + +import java.util.Map; + +public final class DefaultIVFContext implements EngineSpecificMethodContext { + + private final Map> supportedMethodParameters = ImmutableMap.>builder() + .put(MethodParameter.NPROBE.getName(), new Parameter.IntegerParameter(MethodParameter.NPROBE.getName(), null, value -> true)) + .build(); + + @Override + public Map> supportedMethodParameters() { + return supportedMethodParameters; + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 7cf31ba3c1..263abac32b 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -332,7 +332,7 @@ private Faiss( ) { super( methods, - Map.of(METHOD_HNSW, new DefaultHnswContext(), METHOD_IVF, EngineSpecificMethodContext.EMPTY), + Map.of(METHOD_HNSW, new DefaultHnswContext(), METHOD_IVF, new DefaultIVFContext()), scoreTranslation, currentVersion, extension diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index b94aad82d3..63f7290b18 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -617,6 +617,7 @@ public void testIVFSQFP16_whenIndexedAndQueried_thenSucceed() { indexTestData(indexName, fieldName, dimension, numDocs); queryTestData(indexName, fieldName, dimension, numDocs); + queryTestData(indexName, fieldName, dimension, numDocs, Map.of("nprobes", 100)); deleteKNNIndex(indexName); validateGraphEviction(); } @@ -1623,13 +1624,23 @@ protected void setupKNNIndexForFilterQuery() throws Exception { refreshIndex(INDEX_NAME); } - private void queryTestData(final String indexName, final String fieldName, final int dimension, final int numDocs) throws IOException, - ParseException { + @SneakyThrows + private void queryTestData(final String indexName, final String fieldName, final int dimension, final int numDocs) { + queryTestData(indexName, fieldName, dimension, numDocs, null); + } + + private void queryTestData( + final String indexName, + final String fieldName, + final int dimension, + final int numDocs, + Map methodParams + ) throws IOException, ParseException { float[] queryVector = new float[dimension]; Arrays.fill(queryVector, (float) numDocs); int k = 10; - Response searchResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k), k); + Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, k, queryVector, methodParams), k); List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); assertEquals(k, results.size()); for (int i = 0; i < k; i++) { diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 874e96ba52..9e205fc6ed 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -510,7 +510,7 @@ private void validateQueries(SpaceType spaceType, String fieldName, Map knnResults = parseSearchResponse(responseBody, fieldName); assertEquals(k, knnResults.size()); @@ -525,27 +525,6 @@ private void validateQueries(SpaceType spaceType, String fieldName, Map methodParams) { - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("query") - .startObject("knn") - .startObject(fieldName) - .field("vector", vector) - .field("k", k); - if (methodParams != null) { - builder.startObject("method_parameters"); - for (Map.Entry entry : methodParams.entrySet()) { - builder.field(entry.getKey(), entry.getValue()); - } - builder.endObject(); - } - - builder.endObject().endObject().endObject().endObject(); - return builder; - } - private List queryResults(final float[] searchVector, final int k) throws Exception { final String responseBody = EntityUtils.toString( searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity() diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index a04eb6f0ff..2f4fd0506f 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -833,7 +833,6 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { assertEquals(HNSW_METHOD_PARAMS, ((KNNQuery) query).getMethodParameters()); } - /** This test should be uncommented once we have nprobs. Considering engine instance is static its not possible to test this right now public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -849,11 +848,11 @@ public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParamete .fieldName(FIELD_NAME) .vector(queryVector) .k(K) - .methodParameters(Map.of("ef_search", 10)) + .methodParameters(Map.of("nprobes", 10)) .build(); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); - }**/ + } public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 860cd2efaa..92739a5e00 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1609,4 +1609,25 @@ protected void addKnnDocWithAttributes( Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + + @SneakyThrows + protected XContentBuilder buildSearchQuery(String fieldName, int k, float[] vector, Map methodParams) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(fieldName) + .field("vector", vector) + .field("k", k); + if (methodParams != null) { + builder.startObject("method_parameters"); + for (Map.Entry entry : methodParams.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } + + builder.endObject().endObject().endObject().endObject(); + return builder; + } }