From fd74d639993ace93393e32399f3d655d34ba2816 Mon Sep 17 00:00:00 2001 From: Patrick Zhai Date: Sun, 30 Apr 2023 14:24:31 -0700 Subject: [PATCH] Add multi-thread searchability to OnHeapHnswGraph --- lucene/CHANGES.txt | 3 +- .../lucene/util/hnsw/HnswGraphSearcher.java | 145 +++++++++++++++--- .../lucene/util/hnsw/HnswGraphTestCase.java | 106 +++++++++++++ 3 files changed, 230 insertions(+), 24 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 4f981b13cc34..026b1cc38a59 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -118,7 +118,8 @@ API Changes New Features --------------------- -(No changes) + +* GITHUB#12257: Create OnHeapHnswGraphSearcher to let OnHeapHnswGraph be searched in a thread-safe manner. (Patrick Zhai) Improvements --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 4857d5b9d577..d6e63f483b2e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -100,28 +100,31 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - NeighborQueue results; + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + } - int initialEp = graph.entryNode(); - if (initialEp == -1) { - return new NeighborQueue(1, true); - } - int[] eps = new int[] {initialEp}; - int numVisited = 0; - for (int level = graph.numLevels() - 1; level >= 1; level--) { - results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); - numVisited += results.visitedCount(); - visitedLimit -= results.visitedCount(); - if (results.incomplete()) { - results.setVisitedCount(numVisited); - return results; - } - eps[0] = results.pop(); - } - results = - graphSearcher.searchLevel(query, topK, 
0, eps, vectors, graph, acceptOrds, visitedLimit); - results.setVisitedCount(results.visitedCount() + numVisited); - return results; + /** + * Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to + * {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding, + * VectorSimilarityFunction, HnswGraph, Bits, int)} + */ + public static NeighborQueue search( + float[] query, + int topK, + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + OnHeapHnswGraph graph, + Bits acceptOrds, + int visitedLimit) + throws IOException { + OnHeapHnswGraphSearcher graphSearcher = + new OnHeapHnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(topK, true), + new SparseFixedBitSet(vectors.size())); + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); } /** @@ -161,6 +164,46 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + } + + /** + * Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to + * {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction, + * HnswGraph, Bits, int)} + */ + public static NeighborQueue search( + byte[] query, + int topK, + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + OnHeapHnswGraph graph, + Bits acceptOrds, + int visitedLimit) + throws IOException { + OnHeapHnswGraphSearcher graphSearcher = + new OnHeapHnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(topK, true), + new SparseFixedBitSet(vectors.size())); + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + } + + private static NeighborQueue search( + T query, + 
int topK, + RandomAccessVectorValues vectors, + HnswGraph graph, + HnswGraphSearcher graphSearcher, + Bits acceptOrds, + int visitedLimit) + throws IOException { + int initialEp = graph.entryNode(); + if (initialEp == -1) { + return new NeighborQueue(1, true); + } NeighborQueue results; int[] eps = new int[] {graph.entryNode()}; int numVisited = 0; @@ -252,9 +295,9 @@ private NeighborQueue searchLevel( } int topCandidateNode = candidates.pop(); - graph.seek(level, topCandidateNode); + graphSeek(graph, level, topCandidateNode); int friendOrd; - while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) { + while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) { assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; if (visited.getAndSet(friendOrd)) { continue; @@ -298,4 +341,60 @@ private void prepareScratchState(int capacity) { } visited.clear(0, visited.length()); } + + /** + * Seek a specific node in the given graph. The default implementation will just call {@link + * HnswGraph#seek(int, int)} + * + * @throws IOException when seeking the graph + */ + void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException { + graph.seek(level, targetNode); + } + + /** + * Get the next neighbor from the graph; you must call {@link #graphSeek(HnswGraph, int, int)} + * before calling this method. The default implementation will just call {@link + * HnswGraph#nextNeighbor()} + * + * @return see {@link HnswGraph#nextNeighbor()} + * @throws IOException when advancing neighbors + */ + int graphNextNeighbor(HnswGraph graph) throws IOException { + return graph.nextNeighbor(); + } + + /** + * This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner. + * + * 

<p>Note the class itself is NOT thread safe, but since each search creates its own graph + * searcher, the search method is thread safe. + */ + private static class OnHeapHnswGraphSearcher extends HnswGraphSearcher { + + private NeighborArray cur; + private int upto; + + private OnHeapHnswGraphSearcher( + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + NeighborQueue candidates, + BitSet visited) { + super(vectorEncoding, similarityFunction, candidates, visited); + } + + @Override + void graphSeek(HnswGraph graph, int level, int targetNode) { + cur = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode); + upto = -1; + } + + @Override + int graphNextNeighbor(HnswGraph graph) { + if (++upto < cur.size()) { + return cur.node[upto]; + } + return NO_MORE_DOCS; + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 9825d4a5f419..e74660405b0d 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -33,6 +33,12 @@ import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene95.Lucene95Codec; @@ -67,6 +73,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.NamedThreadFactory; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; @@ 
-991,6 +998,105 @@ public void testRandom() throws IOException { assertTrue("overlap=" + overlap, overlap > 0.9); } + /* test thread-safety of searching OnHeapHnswGraph */ + @SuppressWarnings("unchecked") + public void testOnHeapHnswGraphSearch() + throws IOException, ExecutionException, InterruptedException, TimeoutException { + int size = atLeast(100); + int dim = atLeast(10); + AbstractMockVectorValues vectors = vectorValues(size, dim); + int topK = 5; + HnswGraphBuilder builder = + HnswGraphBuilder.create( + vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong()); + OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size); + + List queries = new ArrayList<>(); + List expects = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + NeighborQueue expect; + T query = randomVector(dim); + queries.add(query); + expect = + switch (getVectorEncoding()) { + case BYTE -> HnswGraphSearcher.search( + (byte[]) query, + 100, + (RandomAccessVectorValues) vectors, + getVectorEncoding(), + similarityFunction, + hnsw, + acceptOrds, + Integer.MAX_VALUE); + case FLOAT32 -> HnswGraphSearcher.search( + (float[]) query, + 100, + (RandomAccessVectorValues) vectors, + getVectorEncoding(), + similarityFunction, + hnsw, + acceptOrds, + Integer.MAX_VALUE); + }; + + while (expect.size() > topK) { + expect.pop(); + } + expects.add(expect); + } + + ExecutorService exec = + Executors.newFixedThreadPool(4, new NamedThreadFactory("onHeapHnswSearch")); + List> futures = new ArrayList<>(); + for (T query : queries) { + futures.add( + exec.submit( + () -> { + NeighborQueue actual; + try { + actual = + switch (getVectorEncoding()) { + case BYTE -> HnswGraphSearcher.search( + (byte[]) query, + 100, + (RandomAccessVectorValues) vectors, + getVectorEncoding(), + similarityFunction, + hnsw, + acceptOrds, + Integer.MAX_VALUE); + case FLOAT32 -> HnswGraphSearcher.search( + (float[]) query, + 100, + 
(RandomAccessVectorValues) vectors, + getVectorEncoding(), + similarityFunction, + hnsw, + acceptOrds, + Integer.MAX_VALUE); + }; + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + while (actual.size() > topK) { + actual.pop(); + } + return actual; + })); + } + List actuals = new ArrayList<>(); + for (Future future : futures) { + actuals.add(future.get(10, TimeUnit.SECONDS)); + } + exec.shutdownNow(); + for (int i = 0; i < expects.size(); i++) { + NeighborQueue expect = expects.get(i); + NeighborQueue actual = actuals.get(i); + assertArrayEquals(expect.nodes(), actual.nodes()); + } + } + private int computeOverlap(int[] a, int[] b) { Arrays.sort(a); Arrays.sort(b);