Skip to content

Commit

Permalink
Add Unit and Integration tests
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda committed Aug 9, 2023
1 parent 24f2309 commit f5b95b1
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
package org.opensearch.knn;

import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.block.ClusterBlock;
import org.opensearch.cluster.block.ClusterBlockLevel;
import org.opensearch.cluster.block.ClusterBlocks;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
Expand All @@ -32,9 +38,12 @@
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Map;
import java.util.concurrent.ExecutionException;

import static org.mockito.Mockito.when;

public class KNNSingleNodeTestCase extends OpenSearchSingleNodeTestCase {
@Override
public void setUp() throws Exception {
Expand Down Expand Up @@ -154,4 +163,25 @@ public void assertTrainingSucceeds(ModelDao modelDao, String modelId, int attemp

fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms.");
}

// Add Global Cluster Block with the given ClusterBlockLevel
protected void addGlobalClusterBlock(ClusterService clusterService, String description, EnumSet<ClusterBlockLevel> clusterBlockLevels) {
ClusterBlock block = new ClusterBlock(randomInt(), description, false, false, false, RestStatus.FORBIDDEN, clusterBlockLevels);
ClusterBlocks clusterBlocks = ClusterBlocks.builder().addGlobalBlock(block).build();
ClusterState state = ClusterState.builder(ClusterName.DEFAULT).blocks(clusterBlocks).build();
when(clusterService.state()).thenReturn(state);
}

// Add Cluster Block for an Index with given ClusterBlockLevel
protected void addIndexClusterBlock(
ClusterService clusterService,
String description,
EnumSet<ClusterBlockLevel> clusterBlockLevels,
String testIndex
) {
ClusterBlock block = new ClusterBlock(randomInt(), description, false, false, false, RestStatus.FORBIDDEN, clusterBlockLevels);
ClusterBlocks clusterBlocks = ClusterBlocks.builder().addIndexBlock(testIndex, block).build();
ClusterState state = ClusterState.builder(ClusterName.DEFAULT).blocks(clusterBlocks).build();
when(clusterService.state()).thenReturn(state);
}
}
24 changes: 24 additions & 0 deletions src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,28 @@ public void testGetEnginePaths() {
assertEquals(includedFileNames.size(), included.size());
included.keySet().forEach(o -> assertTrue(includedFileNames.contains(o)));
}

public void testClearCache_emptyIndex() throws IOException {
IndexService indexService = createKNNIndex(testIndexName);
createKnnIndexMapping(testIndexName, testFieldName, dimensions);

IndexShard indexShard = indexService.iterator().next();
KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard);
knnIndexShard.clearCache();
assertNull(NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName));
}

public void testClearCache_shardPresentInCache() throws InterruptedException, ExecutionException, IOException {
IndexService indexService = createKNNIndex(testIndexName);
createKnnIndexMapping(testIndexName, testFieldName, dimensions);
addKnnDoc(testIndexName, String.valueOf(randomInt()), testFieldName, new Float[] { randomFloat(), randomFloat() });

IndexShard indexShard = indexService.iterator().next();
KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard);
knnIndexShard.warmup();
assertEquals(1, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName).get(GRAPH_COUNT));

knnIndexShard.clearCache();
assertNull(NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.action;

import org.opensearch.client.Request;
import org.opensearch.client.ResponseException;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.rest.RestRequest;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;

import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE;

/**
* Integration tests to validate ClearCache API
*/

public class RestClearCacheHandlerIT extends KNNRestTestCase {
private static final String TEST_FIELD = "test-field";
private static final int DIMENSIONS = 2;
private static final String ALL_INDICES = "_all";

public void testNonExistentIndex() throws IOException {
String nonExistentIndex = "non-existent-index";

String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, nonExistentIndex);
Request request = new Request(RestRequest.Method.POST.name(), restURI);

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertTrue(ex.getMessage().contains(nonExistentIndex));
}

public void testNotKnnIndex() throws IOException {
String notKNNIndex = "not-knn-index";
createIndex(notKNNIndex, Settings.EMPTY);

String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, notKNNIndex);
Request request = new Request(RestRequest.Method.POST.name(), restURI);

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertTrue(ex.getMessage().contains(notKNNIndex));
}

public void testClearCacheSingleIndex() throws Exception {
String testIndex = getTestName().toLowerCase();
int graphCountBefore = getTotalGraphsInCache();
createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
addKnnDoc(testIndex, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() });

knnWarmup(Collections.singletonList(testIndex));

assertEquals(graphCountBefore + 1, getTotalGraphsInCache());

clearCache(Collections.singletonList(testIndex));
assertEquals(graphCountBefore, getTotalGraphsInCache());
}

public void testClearCacheMultipleIndices() throws Exception {
String testIndex1 = getTestName().toLowerCase();
String testIndex2 = getTestName().toLowerCase() + 1;
int graphCountBefore = getTotalGraphsInCache();

createKnnIndex(testIndex1, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
addKnnDoc(testIndex1, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() });

createKnnIndex(testIndex2, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
addKnnDoc(testIndex2, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() });

knnWarmup(Arrays.asList(testIndex1, testIndex2));

assertEquals(graphCountBefore + 2, getTotalGraphsInCache());

clearCache(Arrays.asList(ALL_INDICES));
assertEquals(graphCountBefore, getTotalGraphsInCache());
}

public void testClearCacheMultipleIndicesWithPatterns() throws Exception {
String testIndex1 = getTestName().toLowerCase();
String testIndex2 = getTestName().toLowerCase() + 1;
String testIndex3 = "abc" + getTestName().toLowerCase();
int graphCountBefore = getTotalGraphsInCache();

createKnnIndex(testIndex1, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
addKnnDoc(testIndex1, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() });

createKnnIndex(testIndex2, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
addKnnDoc(testIndex2, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() });

createKnnIndex(testIndex3, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
addKnnDoc(testIndex3, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() });

knnWarmup(Arrays.asList(testIndex1, testIndex2, testIndex3));

assertEquals(graphCountBefore + 3, getTotalGraphsInCache());
String indexPattern = getTestName().toLowerCase() + "*";

clearCache(Arrays.asList(indexPattern));
assertEquals(graphCountBefore + 1, getTotalGraphsInCache());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.block.ClusterBlockException;
import org.opensearch.cluster.block.ClusterBlockLevel;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.cluster.routing.ShardsIterator;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.index.IndexService;
import org.opensearch.knn.KNNSingleNodeTestCase;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;

import java.io.IOException;
import java.util.EnumSet;
import java.util.concurrent.ExecutionException;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class ClearCacheTransportActionTests extends KNNSingleNodeTestCase {
private static final String TEST_FIELD = "test-field";
private static final int DIMENSIONS = 2;

public void testShardOperation() throws IOException, ExecutionException, InterruptedException {
String testIndex = getTestName().toLowerCase();
KNNWarmupRequest knnWarmupRequest = new KNNWarmupRequest(testIndex);
KNNWarmupTransportAction knnWarmupTransportAction = node().injector().getInstance(KNNWarmupTransportAction.class);
assertEquals(0, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size());

IndexService indexService = createKNNIndex(testIndex);
createKnnIndexMapping(testIndex, TEST_FIELD, DIMENSIONS);
addKnnDoc(testIndex, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() });
ShardRouting shardRouting = indexService.iterator().next().routingEntry();

knnWarmupTransportAction.shardOperation(knnWarmupRequest, shardRouting);
assertEquals(1, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size());

ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex);
ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class);
clearCacheTransportAction.shardOperation(clearCacheRequest, shardRouting);
assertEquals(0, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size());
}

public void testShards() throws InterruptedException, ExecutionException, IOException {
String testIndex = getTestName().toLowerCase();
ClusterService clusterService = node().injector().getInstance(ClusterService.class);
ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class);
ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex);

createKNNIndex(testIndex);
createKnnIndexMapping(testIndex, TEST_FIELD, DIMENSIONS);
addKnnDoc(testIndex, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() });

ShardsIterator shardsIterator = clearCacheTransportAction.shards(
clusterService.state(),
clearCacheRequest,
new String[] { testIndex }
);
assertEquals(1, shardsIterator.size());
}

public void testCheckGlobalBlock_throwsClusterBlockException() {
String testIndex = getTestName().toLowerCase();
String description = "testing metadata block";
ClusterService clusterService = mock(ClusterService.class);
addGlobalClusterBlock(clusterService, description, EnumSet.of(ClusterBlockLevel.METADATA_WRITE));
ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class);
ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex);
ClusterBlockException ex = clearCacheTransportAction.checkGlobalBlock(clusterService.state(), clearCacheRequest);
assertTrue(ex.getMessage().contains(description));
}

public void testCheckGlobalBlock_notThrowsClusterBlockException() {
String testIndex = getTestName().toLowerCase();
ClusterService clusterService = mock(ClusterService.class);
ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class);
ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex);
ClusterState state = ClusterState.builder(ClusterName.DEFAULT).build();
when(clusterService.state()).thenReturn(state);
assertNull(clearCacheTransportAction.checkGlobalBlock(clusterService.state(), clearCacheRequest));
}

public void testCheckRequestBlock_throwsClusterBlockException() {
String testIndex = getTestName().toLowerCase();
String description = "testing index metadata block";
ClusterService clusterService = mock(ClusterService.class);
addIndexClusterBlock(clusterService, description, EnumSet.of(ClusterBlockLevel.METADATA_WRITE), testIndex);

ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class);
ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex);
ClusterBlockException ex = clearCacheTransportAction.checkRequestBlock(
clusterService.state(),
clearCacheRequest,
new String[] { testIndex }
);
assertTrue(ex.getMessage().contains(testIndex));
assertTrue(ex.getMessage().contains(description));

}

public void testCheckRequestBlock_notThrowsClusterBlockException() {
String testIndex = getTestName().toLowerCase();
ClusterService clusterService = mock(ClusterService.class);
ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class);
ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex);
ClusterState state = ClusterState.builder(ClusterName.DEFAULT).build();
when(clusterService.state()).thenReturn(state);
assertNull(clearCacheTransportAction.checkRequestBlock(clusterService.state(), clearCacheRequest, new String[] { testIndex }));
}
}
15 changes: 15 additions & 0 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE;

import static org.opensearch.knn.TestUtils.NUMBER_OF_REPLICAS;
import static org.opensearch.knn.TestUtils.NUMBER_OF_SHARDS;
Expand Down Expand Up @@ -574,6 +575,20 @@ protected Response executeWarmupRequest(List<String> indices, final String baseU
return client().performRequest(request);
}

/**
* Evicts valid k-NN indices from the cache.
*
* @param indices list of k-NN indices that needs to be removed from cache
* @return Response of clear Cache API request
* @throws IOException
*/
protected Response clearCache(List<String> indices) throws IOException {
String indicesSuffix = String.join(",", indices);
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, indicesSuffix);
Request request = new Request("POST", restURI);
return client().performRequest(request);
}

/**
* Parse KNN Cluster stats from response
*/
Expand Down

0 comments on commit f5b95b1

Please sign in to comment.