From 33f0a5d4c6c31d66702eaa67292ce59d48490690 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 28 Apr 2023 20:22:31 -0500 Subject: [PATCH] allocate one NeighborQueue per search for results --- .../lucene/util/hnsw/HnswGraphSearcher.java | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 4857d5b9d577..84d40c841bf8 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -100,7 +100,7 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - NeighborQueue results; + NeighborQueue results = new NeighborQueue(topK, false); int initialEp = graph.entryNode(); if (initialEp == -1) { @@ -109,7 +109,8 @@ public static NeighborQueue search( int[] eps = new int[] {initialEp}; int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { - results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); + results.clear(); + graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit); numVisited += results.visitedCount(); visitedLimit -= results.visitedCount(); if (results.incomplete()) { @@ -118,8 +119,9 @@ public static NeighborQueue search( } eps[0] = results.pop(); } - results = - graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); + results.clear(); + graphSearcher.searchLevel( + results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); results.setVisitedCount(results.visitedCount() + numVisited); return results; } @@ -161,11 +163,12 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - NeighborQueue results; + NeighborQueue results = new NeighborQueue(topK, false); int[] eps = new int[] {graph.entryNode()}; int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { - results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); + results.clear(); + graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit); numVisited += results.visitedCount(); visitedLimit -= results.visitedCount(); @@ -176,8 +179,9 @@ public static NeighborQueue search( } eps[0] = results.pop(); } - results = - graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); + results.clear(); + graphSearcher.searchLevel( + results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); results.setVisitedCount(results.visitedCount() + numVisited); return results; } @@ -205,10 +209,19 @@ public NeighborQueue searchLevel( RandomAccessVectorValues vectors, HnswGraph graph) throws IOException { - return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); + NeighborQueue results = new NeighborQueue(topK, false); + searchLevel(results, query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); + return results; } - private NeighborQueue searchLevel( + /** + * Add the closest neighbors found to a priority queue (heap). These are returned in REVERSE + * proximity order -- the most distant neighbor of the topK found, i.e. the one with the lowest + * score/comparison value, will be at the top of the heap, while the closest neighbor will be the + * last to be popped. + */ + private void searchLevel( + NeighborQueue results, T query, int topK, int level, @@ -219,7 +232,6 @@ private NeighborQueue searchLevel( int visitedLimit) throws IOException { int size = graph.size(); - NeighborQueue results = new NeighborQueue(topK, false); prepareScratchState(vectors.size()); int numVisited = 0; @@ -280,7 +292,6 @@ private NeighborQueue searchLevel( results.pop(); } results.setVisitedCount(numVisited); - return results; } private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {