From 9174a9f64a61e04c2a9784480b5be18651aea589 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Mon, 23 Jan 2023 19:37:50 -0600 Subject: [PATCH] Add Unit and Integration tests for Clear Cache API Signed-off-by: Naveen Tatikonda --- .../opensearch/knn/KNNSingleNodeTestCase.java | 30 +++++ .../knn/index/KNNIndexShardTests.java | 24 ++++ .../action/RestClearCacheHandlerIT.java | 85 +++++++++++++ .../ClearCacheTransportActionTests.java | 116 ++++++++++++++++++ .../org/opensearch/knn/KNNRestTestCase.java | 15 +++ 5 files changed, 270 insertions(+) create mode 100644 src/test/java/org/opensearch/knn/plugin/action/RestClearCacheHandlerIT.java create mode 100644 src/test/java/org/opensearch/knn/plugin/transport/ClearCacheTransportActionTests.java diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index c83433eca5..08ee7976b8 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -5,6 +5,12 @@ package org.opensearch.knn; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlock; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.block.ClusterBlocks; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.knn.index.query.KNNQueryBuilder; @@ -32,9 +38,12 @@ import java.io.IOException; import java.util.Collection; import java.util.Collections; +import java.util.EnumSet; import java.util.Map; import java.util.concurrent.ExecutionException; +import static org.mockito.Mockito.when; + public class KNNSingleNodeTestCase extends OpenSearchSingleNodeTestCase { @Override public void setUp() throws Exception { @@ -154,4 +163,25 @@ public void assertTrainingSucceeds(ModelDao modelDao, String modelId, int attemp fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms."); } + + // Add Global Cluster Block with the given ClusterBlockLevel + protected void addGlobalClusterBlock(ClusterService clusterService, String description, EnumSet clusterBlockLevels) { + ClusterBlock block = new ClusterBlock(randomInt(), description, false, false, false, RestStatus.FORBIDDEN, clusterBlockLevels); + ClusterBlocks clusterBlocks = ClusterBlocks.builder().addGlobalBlock(block).build(); + ClusterState state = ClusterState.builder(ClusterName.DEFAULT).blocks(clusterBlocks).build(); + when(clusterService.state()).thenReturn(state); + } + + // Add Cluster Block for an Index with given ClusterBlockLevel + protected void addIndexClusterBlock( + ClusterService clusterService, + String description, + EnumSet clusterBlockLevels, + String testIndex + ) { + ClusterBlock block = new ClusterBlock(randomInt(), description, false, false, false, RestStatus.FORBIDDEN, clusterBlockLevels); + ClusterBlocks clusterBlocks = ClusterBlocks.builder().addIndexBlock(testIndex, block).build(); + ClusterState state = ClusterState.builder(ClusterName.DEFAULT).blocks(clusterBlocks).build(); + when(clusterService.state()).thenReturn(state); + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java index fc88a8ea67..f6cd5d6ac5 100644 --- a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java @@ -152,4 +152,28 @@ public void testGetEnginePaths() { assertEquals(includedFileNames.size(), included.size()); included.keySet().forEach(o -> assertTrue(includedFileNames.contains(o))); } + + public void testClearCache_emptyIndex() throws IOException { + IndexService indexService = createKNNIndex(testIndexName); + createKnnIndexMapping(testIndexName, testFieldName, dimensions); + + IndexShard indexShard = indexService.iterator().next(); + KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard); + knnIndexShard.clearCache(); + assertNull(NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName)); + } + + public void testClearCache_shardPresentInCache() throws InterruptedException, ExecutionException, IOException { + IndexService indexService = createKNNIndex(testIndexName); + createKnnIndexMapping(testIndexName, testFieldName, dimensions); + addKnnDoc(testIndexName, String.valueOf(randomInt()), testFieldName, new Float[] { randomFloat(), randomFloat() }); + + IndexShard indexShard = indexService.iterator().next(); + KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard); + knnIndexShard.warmup(); + assertEquals(1, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName).get(GRAPH_COUNT)); + + knnIndexShard.clearCache(); + assertNull(NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName)); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestClearCacheHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestClearCacheHandlerIT.java new file mode 100644 index 0000000000..cbdaff3574 --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/action/RestClearCacheHandlerIT.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.action; + +import org.junit.Test; +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE; + +/** + * Integration tests to validate ClearCache API + */ + +public class RestClearCacheHandlerIT extends KNNRestTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 2; + + // @Test(expected = ResponseException.class) + public void testNonExistentIndex() throws IOException { + String nonExistentIndex = "non-existent-index"; + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, nonExistentIndex); + Request request = new Request(RestRequest.Method.GET.name(), restURI); + + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertTrue(ex.getMessage().contains(nonExistentIndex)); + } + + @Test(expected = ResponseException.class) + public void testNotKnnIndex() throws IOException { + String notKNNIndex = "not-KNN-index"; + createIndex(notKNNIndex, Settings.EMPTY); + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, notKNNIndex); + Request request = new Request(RestRequest.Method.GET.name(), restURI); + + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertTrue(ex.getMessage().contains(notKNNIndex)); + } + + public void testClearCacheSingleIndex() throws Exception { + String testIndex = getTestName().toLowerCase(); + int graphCountBefore = getTotalGraphsInCache(); + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKnnDoc(testIndex, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + knnWarmup(Collections.singletonList(testIndex)); + + assertEquals(graphCountBefore + 1, getTotalGraphsInCache()); + + clearCache(Collections.singletonList(testIndex)); + assertEquals(graphCountBefore, getTotalGraphsInCache()); + } + + public void testClearCacheMultipleIndices() throws Exception { + String testIndex1 = getTestName().toLowerCase(); + String testIndex2 = getTestName().toLowerCase() + 1; + int graphCountBefore = getTotalGraphsInCache(); + + createKnnIndex(testIndex1, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKnnDoc(testIndex1, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + createKnnIndex(testIndex2, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKnnDoc(testIndex2, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + knnWarmup(Arrays.asList(testIndex1, testIndex2)); + + assertEquals(graphCountBefore + 2, getTotalGraphsInCache()); + + clearCache(Arrays.asList(testIndex1, testIndex2)); + assertEquals(graphCountBefore, getTotalGraphsInCache()); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/ClearCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/ClearCacheTransportActionTests.java new file mode 100644 index 0000000000..1ca254f2db --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/transport/ClearCacheTransportActionTests.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockException; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardsIterator; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.index.IndexService; +import org.opensearch.knn.KNNSingleNodeTestCase; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.concurrent.ExecutionException; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ClearCacheTransportActionTests extends KNNSingleNodeTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 2; + + public void testShardOperation() throws IOException, ExecutionException, InterruptedException { + String testIndex = getTestName().toLowerCase(); + KNNWarmupRequest knnWarmupRequest = new KNNWarmupRequest(testIndex); + KNNWarmupTransportAction knnWarmupTransportAction = node().injector().getInstance(KNNWarmupTransportAction.class); + assertEquals(0, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size()); + + IndexService indexService = createKNNIndex(testIndex); + createKnnIndexMapping(testIndex, TEST_FIELD, DIMENSIONS); + addKnnDoc(testIndex, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + ShardRouting shardRouting = indexService.iterator().next().routingEntry(); + + knnWarmupTransportAction.shardOperation(knnWarmupRequest, shardRouting); + assertEquals(1, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size()); + + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + clearCacheTransportAction.shardOperation(clearCacheRequest, shardRouting); + assertEquals(0, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size()); + } + + public void testShards() throws InterruptedException, ExecutionException, IOException { + String testIndex = getTestName().toLowerCase(); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + + createKNNIndex(testIndex); + createKnnIndexMapping(testIndex, TEST_FIELD, DIMENSIONS); + addKnnDoc(testIndex, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + ShardsIterator shardsIterator = clearCacheTransportAction.shards( + clusterService.state(), + clearCacheRequest, + new String[] { testIndex } + ); + assertEquals(1, shardsIterator.size()); + } + + public void testCheckGlobalBlock_throwsClusterBlockException() { + String testIndex = getTestName().toLowerCase(); + String description = "testing metadata block"; + ClusterService clusterService = mock(ClusterService.class); + addGlobalClusterBlock(clusterService, description, EnumSet.of(ClusterBlockLevel.METADATA_READ)); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClusterBlockException ex = clearCacheTransportAction.checkGlobalBlock(clusterService.state(), clearCacheRequest); + assertTrue(ex.getMessage().contains(description)); + } + + public void testCheckGlobalBlock_notThrowsClusterBlockException() { + String testIndex = getTestName().toLowerCase(); + ClusterService clusterService = mock(ClusterService.class); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClusterState state = ClusterState.builder(ClusterName.DEFAULT).build(); + when(clusterService.state()).thenReturn(state); + assertNull(clearCacheTransportAction.checkGlobalBlock(clusterService.state(), clearCacheRequest)); + } + + public void testCheckRequestBlock_throwsClusterBlockException() { + String testIndex = getTestName().toLowerCase(); + String description = "testing index metadata block"; + ClusterService clusterService = mock(ClusterService.class); + addIndexClusterBlock(clusterService, description, EnumSet.of(ClusterBlockLevel.METADATA_READ), testIndex); + + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClusterBlockException ex = clearCacheTransportAction.checkRequestBlock( + clusterService.state(), + clearCacheRequest, + new String[] { testIndex } + ); + assertTrue(ex.getMessage().contains(testIndex)); + assertTrue(ex.getMessage().contains(description)); + + } + + public void testCheckRequestBlock_notThrowsClusterBlockException() { + String testIndex = getTestName().toLowerCase(); + ClusterService clusterService = mock(ClusterService.class); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClusterState state = ClusterState.builder(ClusterName.DEFAULT).build(); + when(clusterService.state()).thenReturn(state); + assertNull(clearCacheTransportAction.checkRequestBlock(clusterService.state(), clearCacheRequest, new String[] { testIndex })); + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index b0f1b817d5..9ff9dd9c3e 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -82,6 +82,7 @@ import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE; import static org.opensearch.knn.TestUtils.NUMBER_OF_REPLICAS; import static org.opensearch.knn.TestUtils.NUMBER_OF_SHARDS; @@ -557,6 +558,20 @@ protected Response executeWarmupRequest(List indices, final String baseU return client().performRequest(request); } + /** + * Evicts valid k-NN indices from the cache. + * + * @param indices list of k-NN indices that needs to be removed from cache + * @return Response of clear Cache API request + * @throws IOException + */ + protected Response clearCache(List indices) throws IOException { + String indicesSuffix = String.join(",", indices); + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, indicesSuffix); + Request request = new Request("GET", restURI); + return client().performRequest(request); + } + /** * Parse KNN Cluster stats from response */