Skip to content

Commit

Permalink
gh-12627: HnswGraphBuilder connects disconnected HNSW graph components (
Browse files Browse the repository at this point in the history
  • Loading branch information
msokolov authored and Michael Sokolov committed Aug 9, 2024
1 parent a546ed5 commit 98fbbe2
Show file tree
Hide file tree
Showing 12 changed files with 723 additions and 141 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ Improvements

* GITHUB#13633: Add ability to read/write knn vector values to a MemoryIndex. (Ben Trent)

* GITHUB#12627: Patch HNSW graphs to improve reachability of all nodes from entry points. (Michael Sokolov)

Optimizations
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ public void seek(int level, int targetOrd) throws IOException {
// unsafe; no bounds checking
dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
arcCount = dataIn.readVInt();
assert arcCount <= currentNeighborsBuffer.length : "too many neighbors: " + arcCount;
if (arcCount > 0) {
currentNeighborsBuffer[0] = dataIn.readVInt();
for (int i = 1; i < arcCount; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ public T copyValue(T vectorValue) {
throw new UnsupportedOperationException();
}

OnHeapHnswGraph getGraph() {
OnHeapHnswGraph getGraph() throws IOException {
assert flatFieldVectorsWriter.isFinished();
if (node > 0) {
return hnswGraphBuilder.getGraph();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,12 @@ public interface HnswBuilder {
void setInfoStream(InfoStream infoStream);

/**
 * Returns the graph held by this builder. In the implementations shown with this interface this
 * is a plain accessor: it does not run the finalization step performed by {@link
 * #getCompletedGraph()}.
 */
OnHeapHnswGraph getGraph();

/**
 * Finalizes the graph and returns it. Once this method is called no further updates to the graph
 * are accepted (addGraphNode will throw IllegalStateException). Final modifications to the graph
 * (e.g. patching up disconnected components, re-ordering node ids for better delta compression)
 * may be triggered, so callers should expect this call to take some time.
 *
 * @return the completed graph
 * @throws IOException presumably when the final modifications fail while scoring or reading
 *     vectors - TODO confirm against the implementations
 */
OnHeapHnswGraph getCompletedGraph() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
private final ConcurrentMergeWorker[] workers;
private final HnswLock hnswLock;
private InfoStream infoStream = InfoStream.getDefault();
private boolean frozen;

public HnswConcurrentMergeBuilder(
TaskExecutor taskExecutor,
Expand Down Expand Up @@ -87,7 +88,9 @@ public OnHeapHnswGraph build(int maxOrd) throws IOException {
});
}
taskExecutor.invokeAll(futures);
return workers[0].getGraph();
finish();
frozen = true;
return workers[0].getCompletedGraph();
}

@Override
Expand All @@ -103,6 +106,20 @@ public void setInfoStream(InfoStream infoStream) {
}
}

/**
 * Returns the merged graph, running the one-time finishing step first if it has not run yet.
 *
 * @return the graph exposed by {@link #getGraph()}
 * @throws IOException if the finishing step fails
 */
@Override
public OnHeapHnswGraph getCompletedGraph() throws IOException {
  if (frozen) {
    // the finishing step already ran (normally from build()); just hand back the graph
    return getGraph();
  }
  // should already have been called in build(), but just in case
  finish();
  frozen = true;
  return getGraph();
}

// Delegates the final graph modifications to the first worker; build() and getGraph() read the
// resulting graph from workers[0] as well.
private void finish() throws IOException {
workers[0].finish();
}

@Override
public OnHeapHnswGraph getGraph() {
return workers[0].getGraph();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
package org.apache.lucene.util.hnsw;

import static java.lang.Math.log;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
Expand All @@ -28,6 +31,7 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.HnswUtil.Component;

/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
Expand Down Expand Up @@ -66,6 +70,7 @@ public class HnswGraphBuilder implements HnswBuilder {
protected final HnswLock hnswLock;

private InfoStream infoStream = InfoStream.getDefault();
private boolean frozen;

public static HnswGraphBuilder create(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed)
Expand Down Expand Up @@ -136,7 +141,7 @@ protected HnswGraphBuilder(
HnswGraphSearcher graphSearcher)
throws IOException {
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
throw new IllegalArgumentException("M (max connections) must be positive");
}
if (beamWidth <= 0) {
throw new IllegalArgumentException("beamWidth must be positive");
Expand Down Expand Up @@ -168,6 +173,15 @@ public void setInfoStream(InfoStream infoStream) {
this.infoStream = infoStream;
}

/**
 * Returns the graph after making sure the one-time finishing step has been applied.
 *
 * @return the graph exposed by {@link #getGraph()}
 * @throws IOException if the finishing step fails
 */
@Override
public OnHeapHnswGraph getCompletedGraph() throws IOException {
  if (frozen) {
    // already finished on an earlier call; nothing left to do
    return getGraph();
  }
  finish();
  frozen = true;
  return getGraph();
}

@Override
public OnHeapHnswGraph getGraph() {
return hnsw;
Expand Down Expand Up @@ -389,6 +403,93 @@ private static int getRandomGraphLevel(double ml, SplittableRandom random) {
return ((int) (-log(randDouble) * ml));
}

/**
 * Applies the final modifications to the graph (connecting disconnected components) and marks
 * this builder as frozen.
 *
 * <p>The call is idempotent: once the builder is frozen, further calls return immediately. This
 * matters because HnswConcurrentMergeBuilder.build() calls {@code finish()} on its first worker
 * and then calls {@code getCompletedGraph()} on that same worker; without the guard the
 * component-connection pass over every level would run twice.
 *
 * @throws IOException if connecting the components fails
 */
void finish() throws IOException {
  if (frozen) {
    return;
  }
  connectComponents();
  // only freeze after a successful pass, matching getCompletedGraph(), which sets the flag
  // after finish() returns
  frozen = true;
}

/**
 * Runs the per-level component connection on every level of the graph, starting at level 0, and
 * reports failures and the total elapsed time on the info stream when the HNSW_COMPONENT channel
 * is enabled.
 *
 * @throws IOException if connecting the components of a level fails
 */
private void connectComponents() throws IOException {
  final long startNanos = System.nanoTime();
  int level = 0;
  while (level < hnsw.numLevels()) {
    final boolean connected = connectComponents(level);
    if (!connected && infoStream.isEnabled(HNSW_COMPONENT)) {
      infoStream.message(HNSW_COMPONENT, "connectComponents failed on level " + level);
    }
    level++;
  }
  if (infoStream.isEnabled(HNSW_COMPONENT)) {
    final long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000;
    infoStream.message(HNSW_COMPONENT, "connectComponents " + elapsedMillis + " ms");
  }
}

/**
 * Tries to attach every disconnected component on {@code level} to the largest component of that
 * level by adding a link between one node of each.
 *
 * @param level the graph level to patch
 * @return {@code true} if the level had at most one component or every other component was
 *     linked to the largest one; {@code false} if the largest component has no node with room
 *     for a new connection, or if at least one component could not be linked
 * @throws IOException if scoring or searching the graph fails
 */
private boolean connectComponents(int level) throws IOException {
  FixedBitSet notFullyConnected = new FixedBitSet(hnsw.size());
  // level 0 is allowed 2 * M connections per node, the upper levels M
  int maxConn = M;
  if (level == 0) {
    maxConn *= 2;
  }
  List<Component> components = HnswUtil.components(hnsw, level, notFullyConnected, maxConn);
  if (components.size() <= 1) {
    // nothing to connect
    return true;
  }
  // connect other components to the largest one
  Component c0 = components.stream().max(Comparator.comparingInt(Component::size)).orElseThrow();
  if (c0.start() == NO_MORE_DOCS) {
    // the component is already fully connected - no room for new connections
    return false;
  }
  boolean result = true;
  // try for more connections? We only do one since otherwise they may become full
  // while linking
  GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(1);
  int[] eps = new int[1];
  for (Component c : components) {
    if (c == c0) {
      continue;
    }
    if (c.start() == NO_MORE_DOCS) {
      // This component has no node with room for a connection, so start() is the NO_MORE_DOCS
      // sentinel rather than a node ordinal; it must not be handed to the scorer or to link().
      result = false;
      continue;
    }
    beam.clear();
    eps[0] = c0.start();
    RandomVectorScorer scorer = scorerSupplier.scorer(c.start());
    // find the closest node in the largest component to the lowest-numbered node in this
    // component that has room to make a connection. The search must run on the level being
    // patched: searching level 0 could return nodes that do not exist on this level.
    graphSearcher.searchLevel(beam, scorer, level, eps, hnsw, notFullyConnected);
    boolean linked = false;
    while (beam.size() > 0) {
      // read the score before popping: minimumScore() describes the entry popNode() removes
      float score = beam.minimumScore();
      int c0node = beam.popNode();
      if (c0node == c.start()) {
        // NOTE(review): defensive - never link a node to itself. Only possible if the search
        // from the largest component can reach this component's start node; confirm against
        // HnswUtil.components.
        continue;
      }
      assert notFullyConnected.get(c0node);
      // link the nodes
      link(level, c0node, c.start(), score, notFullyConnected);
      linked = true;
    }
    if (!linked) {
      result = false;
    }
  }
  return result;
}

// Links n0 and n1 on the given level. The n0 -> n1 connection is always added (n0 is asserted to
// have room); the reverse n1 -> n0 connection is added only when n1 still has room. Any node
// that reaches its connection limit is cleared from notFullyConnected.
private void link(int level, int n0, int n1, float score, FixedBitSet notFullyConnected) {
  final NeighborArray forward = hnsw.getNeighbors(level, n0);
  final NeighborArray reverse = hnsw.getNeighbors(level, n1);
  // The nodes array is one slot larger than the configured max neighbors (M / 2M), hence the
  // subtraction. Callers are expected to have searched for not-full nodes already.
  final int limit = forward.nodes().length - 1;
  assert notFullyConnected.get(n0);
  assert forward.size() < limit : "node " + n0 + " is full, has " + forward.size() + " friends";

  forward.addOutOfOrder(n1, score);
  if (forward.size() == limit) {
    notFullyConnected.clear(n0);
  }

  if (reverse.size() >= limit) {
    // no room for the back-link; leave the connection one-directional
    return;
  }
  reverse.addOutOfOrder(n0, score);
  if (reverse.size() == limit) {
    notFullyConnected.clear(n1);
  }
}

/**
* A restricted, specialized knnCollector that can be used when building a graph.
*
Expand Down
Loading

0 comments on commit 98fbbe2

Please sign in to comment.