diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 6d387eec40..63529fccc0 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -98,4 +98,7 @@ public class KNNConstants { private static final String JNI_LIBRARY_PREFIX = "opensearchknn_"; public static final String FAISS_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME; public static final String NMSLIB_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + NMSLIB_NAME; + + // API Constants + public static final String CLEAR_CACHE = "clear_cache"; } diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index a12b0df0f6..9eb4860ad9 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -17,6 +17,7 @@ import org.opensearch.index.engine.Engine; import org.opensearch.index.shard.IndexShard; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -27,6 +28,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -100,6 +102,32 @@ public void warmup() throws IOException { } } + /** + * Removes all the k-NN segments for this shard from the cache. + * Adding write lock onto the NativeMemoryAllocation of the index that needs to be evicted from cache. + * Write lock will be unlocked after the index is evicted. This locking mechanism is used to avoid + * conflicts with queries fired on this index when the index is being evicted from cache. + */ + public void clearCache() { + String indexName = getIndexName(); + Optional indexAllocationOptional; + NativeMemoryAllocation indexAllocation; + indexAllocationOptional = nativeMemoryCacheManager.getIndexMemoryAllocation(indexName); + if (indexAllocationOptional.isPresent()) { + indexAllocation = indexAllocationOptional.get(); + indexAllocation.writeLock(); + logger.info("[KNN] Evicting index from cache: [{}]", indexName); + try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-clear-cache")) { + getAllEnginePaths(searcher.getIndexReader()).forEach((key, value) -> nativeMemoryCacheManager.invalidate(key)); + } catch (IOException ex) { + logger.error("[KNN] Failed to evict index from cache: [{}]", indexName); + throw new RuntimeException(ex); + } finally { + indexAllocation.writeUnlock(); + } + } + } + /** * For the given shard, get all of its engine paths * diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index 8b3a3bce15..9478e1e006 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -27,6 +27,7 @@ import java.io.Closeable; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -303,6 +304,23 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext nativeMemoryEntryC return cache.get(nativeMemoryEntryContext.getKey(), nativeMemoryEntryContext::load); } + /** + * Returns the NativeMemoryAllocation associated with given index + * @param indexName name of OpenSearch index + * @return NativeMemoryAllocation associated with given index + */ + public Optional getIndexMemoryAllocation(String indexName) { + Validate.notNull(indexName, "Index name cannot be null"); + return cache.asMap() + .values() + .stream() + .filter(nativeMemoryAllocation -> nativeMemoryAllocation instanceof NativeMemoryAllocation.IndexAllocation) + .filter( + indexAllocation -> indexName.equals(((NativeMemoryAllocation.IndexAllocation) indexAllocation).getOpenSearchIndexName()) + ) + .findFirst(); + } + /** * Invalidate entry from the cache. * diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 54011c2591..a5c75445a5 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -30,6 +30,7 @@ import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler; import org.opensearch.knn.plugin.rest.RestSearchModelHandler; import org.opensearch.knn.plugin.rest.RestTrainModelHandler; +import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.transport.DeleteModelAction; @@ -40,6 +41,8 @@ import org.opensearch.knn.plugin.transport.KNNStatsTransportAction; import org.opensearch.knn.plugin.transport.KNNWarmupAction; import org.opensearch.knn.plugin.transport.KNNWarmupTransportAction; +import org.opensearch.knn.plugin.transport.ClearCacheAction; +import org.opensearch.knn.plugin.transport.ClearCacheTransportAction; import com.google.common.collect.ImmutableList; import org.opensearch.action.ActionRequest; @@ -231,6 +234,7 @@ public List getRestHandlers( RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler(); RestTrainModelHandler restTrainModelHandler = new RestTrainModelHandler(); RestSearchModelHandler restSearchModelHandler = new RestSearchModelHandler(); + RestClearCacheHandler restClearCacheHandler = new RestClearCacheHandler(clusterService, indexNameExpressionResolver); return ImmutableList.of( restKNNStatsHandler, @@ -238,7 +242,8 @@ public List getRestHandlers( restGetModelHandler, restDeleteModelHandler, restTrainModelHandler, - restSearchModelHandler + restSearchModelHandler, + restClearCacheHandler ); } @@ -258,7 +263,8 @@ public List getRestHandlers( new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class), - new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class) + new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class), + new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestClearCacheHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestClearCacheHandler.java new file mode 100644 index 0000000000..46ac4d93dc --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestClearCacheHandler.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.rest; + +import com.google.common.collect.ImmutableList; +import lombok.AllArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.index.Index; +import org.opensearch.knn.common.exception.KNNInvalidIndicesException; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.knn.plugin.transport.ClearCacheAction; +import org.opensearch.knn.plugin.transport.ClearCacheRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; + +import static org.opensearch.action.support.IndicesOptions.strictExpandOpen; +import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; + +/** + * RestHandler for k-NN Clear Cache API. API provides the ability for a user to evict those indices from Cache. + */ +@AllArgsConstructor +public class RestClearCacheHandler extends BaseRestHandler { + private static final Logger logger = LogManager.getLogger(RestClearCacheHandler.class); + + private static final String INDEX = "index"; + public static String NAME = "knn_clear_cache_action"; + private final ClusterService clusterService; + private final IndexNameExpressionResolver indexNameExpressionResolver; + + /** + * @return name of Clear Cache API action + */ + @Override + public String getName() { + return NAME; + } + + /** + * @return Immutable List of Clear Cache API endpoint + */ + @Override + public List routes() { + return ImmutableList.of( + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, INDEX)) + ); + } + + /** + * @param request RestRequest + * @param client NodeClient + * @return RestChannelConsumer + */ + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + ClearCacheRequest clearCacheRequest = createClearCacheRequest(request); + logger.info("[KNN] ClearCache started for the following indices: [{}]", String.join(",", clearCacheRequest.indices())); + return channel -> client.execute(ClearCacheAction.INSTANCE, clearCacheRequest, new RestToXContentListener<>(channel)); + } + + // Create a clear cache request by processing the rest request and validating the indices + private ClearCacheRequest createClearCacheRequest(RestRequest request) { + String[] indexNames = Strings.splitStringByCommaToArray(request.param("index")); + Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), strictExpandOpen(), indexNames); + validateIndices(indices); + + return new ClearCacheRequest(indexNames); + } + + // Validate if the given indices are k-NN indices or not. If there are any invalid indices, + // the request is rejected and an exception is thrown. + private void validateIndices(Index[] indices) { + List invalidIndexNames = Arrays.stream(indices) + .filter(index -> !"true".equals(clusterService.state().metadata().getIndexSafe(index).getSettings().get(KNN_INDEX))) + .map(Index::getName) + .collect(Collectors.toList()); + + if (!invalidIndexNames.isEmpty()) { + throw new KNNInvalidIndicesException( + invalidIndexNames, + "ClearCache request rejected. One or more indices have 'index.knn' set to false." + ); + } + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheAction.java b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheAction.java new file mode 100644 index 0000000000..12b77eb6ff --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.common.io.stream.Writeable; + +/** + * Action associated with ClearCache + */ +public class ClearCacheAction extends ActionType { + + public static final ClearCacheAction INSTANCE = new ClearCacheAction(); + public static final String NAME = "cluster:admin/clear_cache_action"; + + private ClearCacheAction() { + super(NAME, ClearCacheResponse::new); + } + + @Override + public Writeable.Reader getResponseReader() { + return ClearCacheResponse::new; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheRequest.java new file mode 100644 index 0000000000..7187681379 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheRequest.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.support.broadcast.BroadcastRequest; +import org.opensearch.common.io.stream.StreamInput; + +import java.io.IOException; + +/** + * Clear Cache Request. This request contains a list of indices which needs to be evicted from Cache. + */ +public class ClearCacheRequest extends BroadcastRequest { + + /** + * Constructor + * + * @param in input stream + * @throws IOException if read from stream fails + */ + public ClearCacheRequest(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructor + * + * @param indices list of indices which needs to be evicted from cache + */ + public ClearCacheRequest(String... indices) { + super(indices); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheResponse.java new file mode 100644 index 0000000000..416c53413c --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheResponse.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.support.DefaultShardOperationFailedException; +import org.opensearch.action.support.broadcast.BroadcastResponse; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContentObject; + +import java.io.IOException; +import java.util.List; + +/** + * {@link ClearCacheResponse} represents Response returned by {@link ClearCacheRequest}. + * Returns total number of shards on which ClearCache was performed on, as well as + * the number of shards that succeeded and the number of shards that failed. + */ +public class ClearCacheResponse extends BroadcastResponse implements ToXContentObject { + + /** + * Constructor + * + * @param in input stream + * @throws IOException if read from stream fails + */ + public ClearCacheResponse(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructor + * + * @param totalShards total number of shards on which ClearCache was performed + * @param successfulShards number of shards that succeeded + * @param failedShards number of shards that failed + * @param shardFailures list of shard failure exceptions + */ + public ClearCacheResponse( + int totalShards, + int successfulShards, + int failedShards, + List shardFailures + ) { + super(totalShards, successfulShards, failedShards, shardFailures); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheTransportAction.java new file mode 100644 index 0000000000..60f5dd21d4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheTransportAction.java @@ -0,0 +1,168 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.DefaultShardOperationFailedException; +import org.opensearch.action.support.broadcast.node.TransportBroadcastByNodeAction; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockException; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardsIterator; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.index.Index; +import org.opensearch.index.IndexService; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.indices.IndicesService; +import org.opensearch.knn.index.KNNIndexShard; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.List; + +/** + * Transport Action to evict k-NN indices from Cache. TransportBroadcastByNodeAction will distribute the request to + * all shards across the cluster for the given indices. For each shard, shardOperation will be called and the + * indices will be cleared from cache. + */ +public class ClearCacheTransportAction extends TransportBroadcastByNodeAction< + ClearCacheRequest, + ClearCacheResponse, + TransportBroadcastByNodeAction.EmptyResult> { + + public static Logger logger = LogManager.getLogger(ClearCacheTransportAction.class); + + private IndicesService indicesService; + + /** + * Constructor + * + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters ActionFilters + * @param indexNameExpressionResolver IndexNameExpressionResolver + * @param indicesService IndicesService + */ + @Inject + public ClearCacheTransportAction( + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + IndicesService indicesService + ) { + super( + ClearCacheAction.NAME, + clusterService, + transportService, + actionFilters, + indexNameExpressionResolver, + ClearCacheRequest::new, + ThreadPool.Names.SEARCH + ); + this.indicesService = indicesService; + } + + /** + * @param streamInput StreamInput + * @return EmptyResult + * @throws IOException + */ + @Override + protected EmptyResult readShardResult(StreamInput streamInput) throws IOException { + return EmptyResult.readEmptyResultFrom(streamInput); + } + + /** + * @param request ClearCacheRequest + * @param totalShards total number of shards on which ClearCache was performed + * @param successfulShards number of shards that succeeded + * @param failedShards number of shards that failed + * @param emptyResults List of EmptyResult + * @param shardFailures list of shard failure exceptions + * @param clusterState ClusterState + * @return {@link ClearCacheResponse} + */ + @Override + protected ClearCacheResponse newResponse( + ClearCacheRequest request, + int totalShards, + int successfulShards, + int failedShards, + List emptyResults, + List shardFailures, + ClusterState clusterState + ) { + return new ClearCacheResponse(totalShards, successfulShards, failedShards, shardFailures); + } + + /** + * @param streamInput StreamInput + * @return {@link ClearCacheRequest} + * @throws IOException + */ + @Override + protected ClearCacheRequest readRequestFrom(StreamInput streamInput) throws IOException { + return new ClearCacheRequest(streamInput); + } + + /** + * Operation performed at a shard level on all the shards of given index where the index is removed from the cache. + * + * @param request ClearCacheRequest + * @param shardRouting ShardRouting of given shard + * @return EmptyResult + * @throws IOException + */ + @Override + protected EmptyResult shardOperation(ClearCacheRequest request, ShardRouting shardRouting) throws IOException { + Index index = shardRouting.shardId().getIndex(); + IndexService indexService = indicesService.indexServiceSafe(index); + IndexShard indexShard = indexService.getShard(shardRouting.shardId().id()); + KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard); + knnIndexShard.clearCache(); + return EmptyResult.INSTANCE; + } + + /** + * @param clusterState ClusterState + * @param request ClearCacheRequest + * @param concreteIndices Indices in the request + * @return ShardsIterator with all the shards for given concrete indices + */ + @Override + protected ShardsIterator shards(ClusterState clusterState, ClearCacheRequest request, String[] concreteIndices) { + return clusterState.routingTable().allShards(concreteIndices); + } + + /** + * @param clusterState ClusterState + * @param request ClearCacheRequest + * @return ClusterBlockException if there is any global cluster block at a cluster block level of "METADATA_WRITE" + */ + @Override + protected ClusterBlockException checkGlobalBlock(ClusterState clusterState, ClearCacheRequest request) { + return clusterState.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } + + /** + * @param clusterState ClusterState + * @param request ClearCacheRequest + * @param concreteIndices Indices in the request + * @return ClusterBlockException if there is any cluster block on any of the given indices at a cluster block level of "METADATA_WRITE" + */ + @Override + protected ClusterBlockException checkRequestBlock(ClusterState clusterState, ClearCacheRequest request, String[] concreteIndices) { + return clusterState.blocks().indicesBlockedException(ClusterBlockLevel.METADATA_WRITE, concreteIndices); + } +}