diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 84d40c841bf8..4857d5b9d577 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -100,7 +100,7 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - NeighborQueue results = new NeighborQueue(topK, false); + NeighborQueue results; int initialEp = graph.entryNode(); if (initialEp == -1) { @@ -109,8 +109,7 @@ public static NeighborQueue search( int[] eps = new int[] {initialEp}; int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { - results.clear(); - graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit); + results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); numVisited += results.visitedCount(); visitedLimit -= results.visitedCount(); if (results.incomplete()) { @@ -119,9 +118,8 @@ public static NeighborQueue search( } eps[0] = results.pop(); } - results.clear(); - graphSearcher.searchLevel( - results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); + results = + graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); results.setVisitedCount(results.visitedCount() + numVisited); return results; } @@ -163,12 +161,11 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - NeighborQueue results = new NeighborQueue(topK, false); + NeighborQueue results; int[] eps = new int[] {graph.entryNode()}; int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { - results.clear(); - graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit); + results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); numVisited += results.visitedCount(); visitedLimit -= results.visitedCount(); @@ -179,9 +176,8 @@ public static NeighborQueue search( } eps[0] = results.pop(); } - results.clear(); - graphSearcher.searchLevel( - results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); + results = + graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); results.setVisitedCount(results.visitedCount() + numVisited); return results; } @@ -209,19 +205,10 @@ public NeighborQueue searchLevel( RandomAccessVectorValues vectors, HnswGraph graph) throws IOException { - NeighborQueue results = new NeighborQueue(topK, false); - searchLevel(results, query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); - return results; + return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); } - /** - * Add the closest neighbors found to a priority queue (heap). These are returned in REVERSE - * proximity order -- the most distant neighbor of the topK found, i.e. the one with the lowest - * score/comparison value, will be at the top of the heap, while the closest neighbor will be the - * last to be popped. - */ - private void searchLevel( - NeighborQueue results, + private NeighborQueue searchLevel( T query, int topK, int level, @@ -232,6 +219,7 @@ private void searchLevel( int visitedLimit) throws IOException { int size = graph.size(); + NeighborQueue results = new NeighborQueue(topK, false); prepareScratchState(vectors.size()); int numVisited = 0; @@ -292,6 +280,7 @@ private void searchLevel( results.pop(); } results.setVisitedCount(numVisited); + return results; } private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {