diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index d2c76db4678a..52c1691664ae 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -66,9 +66,6 @@ public HnswGraphSearcher( /** * Searches HNSW graph for the nearest neighbors of a query vector. * - *

Note: if you want to search {@link OnHeapHnswGraph} in a thread-safety manner, please - * consider using {@link OnHeapHnswGraphSearcher} - * * @param query search query vector * @param topK the number of nodes to be returned * @param vectors the vector values @@ -106,12 +103,33 @@ public static NeighborQueue search( return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); } + /** + * Searches an {@link OnHeapHnswGraph}. This method is thread safe. For the parameters, refer to + * {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding, + * VectorSimilarityFunction, HnswGraph, Bits, int)}. + */ + public static NeighborQueue search( + float[] query, + int topK, + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + OnHeapHnswGraph graph, + Bits acceptOrds, + int visitedLimit) + throws IOException { + OnHeapHnswGraphSearcher graphSearcher = + new OnHeapHnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(topK, true), + new SparseFixedBitSet(vectors.size())); + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + } + /** * Searches HNSW graph for the nearest neighbors of a query vector. * - *

Note: if you want to search {@link OnHeapHnswGraph} in a thread-safety manner, please - * consider using {@link OnHeapHnswGraphSearcher} - * * @param query search query vector * @param topK the number of nodes to be returned * @param vectors the vector values @@ -149,6 +167,30 @@ public static NeighborQueue search( return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); } + /** + * Searches an {@link OnHeapHnswGraph}. This method is thread safe. For the parameters, refer to + * {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction, + * HnswGraph, Bits, int)}. + */ + public static NeighborQueue search( + byte[] query, + int topK, + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + OnHeapHnswGraph graph, + Bits acceptOrds, + int visitedLimit) + throws IOException { + OnHeapHnswGraphSearcher graphSearcher = + new OnHeapHnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(topK, true), + new SparseFixedBitSet(vectors.size())); + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + } + static NeighborQueue search( T query, int topK, @@ -321,4 +363,38 @@ void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException { int graphNextNeighbor(HnswGraph graph) throws IOException { return graph.nextNeighbor(); } + + /** + * This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner. + * + *

Note that the class itself is NOT thread safe, but since each search creates a new graph + * searcher, the search method is thread safe. + */ + private static class OnHeapHnswGraphSearcher extends HnswGraphSearcher { + + private NeighborArray cur; + private int upto; + + private OnHeapHnswGraphSearcher( + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + NeighborQueue candidates, + BitSet visited) { + super(vectorEncoding, similarityFunction, candidates, visited); + } + + @Override + void graphSeek(HnswGraph graph, int level, int targetNode) { + cur = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode); + upto = -1; + } + + @Override + int graphNextNeighbor(HnswGraph graph) { + if (++upto < cur.size()) { + return cur.node[upto]; + } + return NO_MORE_DOCS; + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraphSearcher.java deleted file mode 100644 index 6374dc08f1e5..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraphSearcher.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.lucene.util.hnsw; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -import java.io.IOException; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.Bits; -import org.apache.lucene.util.SparseFixedBitSet; - -/** - * This class allow {@link OnHeapHnswGraph} to be searched in a thread-safe manner. - * - *

Note the class itself is NOT thread safe, but since each search will create one new graph - * searcher the search method is thread safe. - */ -public class OnHeapHnswGraphSearcher extends HnswGraphSearcher { - - private NeighborArray cur; - private int upto; - - private OnHeapHnswGraphSearcher( - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - NeighborQueue candidates, - BitSet visited) { - super(vectorEncoding, similarityFunction, candidates, visited); - } - - /** - * Searches HNSW graph for the nearest neighbors of a query vector. - * - * @param query search query vector - * @param topK the number of nodes to be returned - * @param vectors the vector values - * @param similarityFunction the similarity function to compare vectors - * @param graph the graph values. May represent the entire graph, or a level in a hierarchical - * graph. - * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or - * {@code null} if they are all allowed to match. 
- * @param visitedLimit the maximum number of nodes that the search is allowed to visit - * @return a priority queue holding the closest neighbors found - */ - public static NeighborQueue search( - float[] query, - int topK, - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - OnHeapHnswGraph graph, - Bits acceptOrds, - int visitedLimit) - throws IOException { - if (query.length != vectors.dimension()) { - throw new IllegalArgumentException( - "vector query dimension: " - + query.length - + " differs from field dimension: " - + vectors.dimension()); - } - OnHeapHnswGraphSearcher graphSearcher = - new OnHeapHnswGraphSearcher<>( - vectorEncoding, - similarityFunction, - new NeighborQueue(topK, true), - new SparseFixedBitSet(vectors.size())); - return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); - } - - /** - * Searches HNSW graph for the nearest neighbors of a query vector. - * - * @param query search query vector - * @param topK the number of nodes to be returned - * @param vectors the vector values - * @param similarityFunction the similarity function to compare vectors - * @param graph the graph values. May represent the entire graph, or a level in a hierarchical - * graph. - * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or - * {@code null} if they are all allowed to match. 
- * @param visitedLimit the maximum number of nodes that the search is allowed to visit - * @return a priority queue holding the closest neighbors found - */ - public static NeighborQueue search( - byte[] query, - int topK, - RandomAccessVectorValues vectors, - VectorEncoding vectorEncoding, - VectorSimilarityFunction similarityFunction, - OnHeapHnswGraph graph, - Bits acceptOrds, - int visitedLimit) - throws IOException { - if (query.length != vectors.dimension()) { - throw new IllegalArgumentException( - "vector query dimension: " - + query.length - + " differs from field dimension: " - + vectors.dimension()); - } - OnHeapHnswGraphSearcher graphSearcher = - new OnHeapHnswGraphSearcher<>( - vectorEncoding, - similarityFunction, - new NeighborQueue(topK, true), - new SparseFixedBitSet(vectors.size())); - return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); - } - - @Override - void graphSeek(HnswGraph graph, int level, int targetNode) { - cur = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode); - upto = -1; - } - - @Override - int graphNextNeighbor(HnswGraph graph) { - if (++upto < cur.size()) { - return cur.node[upto]; - } - return NO_MORE_DOCS; - } -} diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 7c14a77b081d..65e286d730a9 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -962,7 +962,7 @@ public void testRandom() throws IOException { assertTrue("overlap=" + overlap, overlap > 0.9); } - /* test using multi-thread to search OnHeapHnswGraph with OnHeapHnswGraphSearcher */ + /* test thread-safety of searching OnHeapHnswGraph */ @SuppressWarnings("unchecked") public void testOnHeapHnswGraphSearch() throws IOException, ExecutionException, InterruptedException, TimeoutException { @@ -1021,7 +1021,7 @@ 
public void testOnHeapHnswGraphSearch() try { actual = switch (getVectorEncoding()) { - case BYTE -> OnHeapHnswGraphSearcher.search( + case BYTE -> HnswGraphSearcher.search( (byte[]) query, 100, (RandomAccessVectorValues) vectors, @@ -1030,7 +1030,7 @@ public void testOnHeapHnswGraphSearch() hnsw, acceptOrds, Integer.MAX_VALUE); - case FLOAT32 -> OnHeapHnswGraphSearcher.search( + case FLOAT32 -> HnswGraphSearcher.search( (float[]) query, 100, (RandomAccessVectorValues) vectors,