Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allocate one NeighborQueue per search for results #12255

Merged
merged 1 commit into from
May 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public static NeighborQueue search(
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
NeighborQueue results = new NeighborQueue(topK, false);

int initialEp = graph.entryNode();
if (initialEp == -1) {
Expand All @@ -109,7 +109,8 @@ public static NeighborQueue search(
int[] eps = new int[] {initialEp};
int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
results.clear();
graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit);
numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
if (results.incomplete()) {
Expand All @@ -118,8 +119,9 @@ public static NeighborQueue search(
}
eps[0] = results.pop();
}
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.clear();
graphSearcher.searchLevel(
results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
}
Expand Down Expand Up @@ -161,11 +163,12 @@ public static NeighborQueue search(
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
NeighborQueue results = new NeighborQueue(topK, false);
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
results.clear();
graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit);

numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
Expand All @@ -176,8 +179,9 @@ public static NeighborQueue search(
}
eps[0] = results.pop();
}
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.clear();
graphSearcher.searchLevel(
results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
}
Expand Down Expand Up @@ -205,10 +209,19 @@ public NeighborQueue searchLevel(
RandomAccessVectorValues<T> vectors,
HnswGraph graph)
throws IOException {
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
NeighborQueue results = new NeighborQueue(topK, false);
searchLevel(results, query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
return results;
}

private NeighborQueue searchLevel(
/**
* Add the closest neighbors found to a priority queue (heap). These are returned in REVERSE
* proximity order -- the most distant neighbor of the topK found, i.e. the one with the lowest
* score/comparison value, will be at the top of the heap, while the closest neighbor will be the
* last to be popped.
*/
private void searchLevel(
NeighborQueue results,
T query,
int topK,
int level,
Expand All @@ -219,7 +232,6 @@ private NeighborQueue searchLevel(
int visitedLimit)
throws IOException {
int size = graph.size();
NeighborQueue results = new NeighborQueue(topK, false);
prepareScratchState(vectors.size());

int numVisited = 0;
Expand Down Expand Up @@ -280,7 +292,6 @@ private NeighborQueue searchLevel(
results.pop();
}
results.setVisitedCount(numVisited);
return results;
}

private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException {
Expand Down