Skip to content

Commit

Permalink
Add multi-thread searchability to OnHeapHnswGraph (#12257)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaih committed May 22, 2023
1 parent b21b396 commit 25a908d
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 24 deletions.
3 changes: 2 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ API Changes

New Features
---------------------
(No changes)

* GITHUB#12257: Create OnHeapHnswGraphSearcher to let OnHeapHnswGraph to be searched in a thread-safety manner. (Patrick Zhai)

Improvements
---------------------
Expand Down
145 changes: 122 additions & 23 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,28 +100,31 @@ public static NeighborQueue search(
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

int initialEp = graph.entryNode();
if (initialEp == -1) {
return new NeighborQueue(1, true);
}
int[] eps = new int[] {initialEp};
int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
if (results.incomplete()) {
results.setVisitedCount(numVisited);
return results;
}
eps[0] = results.pop();
}
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
/**
* Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to
* {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding,
* VectorSimilarityFunction, HnswGraph, Bits, int)}
*/
public static NeighborQueue search(
float[] query,
int topK,
RandomAccessVectorValues<float[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
OnHeapHnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
OnHeapHnswGraphSearcher<float[]> graphSearcher =
new OnHeapHnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

/**
Expand Down Expand Up @@ -161,6 +164,46 @@ public static NeighborQueue search(
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

/**
* Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to
* {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction,
* HnswGraph, Bits, int)}
*/
public static NeighborQueue search(
byte[] query,
int topK,
RandomAccessVectorValues<byte[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
OnHeapHnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
OnHeapHnswGraphSearcher<byte[]> graphSearcher =
new OnHeapHnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

private static <T> NeighborQueue search(
T query,
int topK,
RandomAccessVectorValues<T> vectors,
HnswGraph graph,
HnswGraphSearcher<T> graphSearcher,
Bits acceptOrds,
int visitedLimit)
throws IOException {
int initialEp = graph.entryNode();
if (initialEp == -1) {
return new NeighborQueue(1, true);
}
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
Expand Down Expand Up @@ -252,9 +295,9 @@ private NeighborQueue searchLevel(
}

int topCandidateNode = candidates.pop();
graph.seek(level, topCandidateNode);
graphSeek(graph, level, topCandidateNode);
int friendOrd;
while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
continue;
Expand Down Expand Up @@ -298,4 +341,60 @@ private void prepareScratchState(int capacity) {
}
visited.clear(0, visited.length());
}

/**
* Seek a specific node in the given graph. The default implementation will just call {@link
* HnswGraph#seek(int, int)}
*
* @throws IOException when seeking the graph
*/
void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException {
graph.seek(level, targetNode);
}

/**
* Get the next neighbor from the graph, you must call {@link #graphSeek(HnswGraph, int, int)}
* before calling this method. The default implementation will just call {@link
* HnswGraph#nextNeighbor()}
*
* @return see {@link HnswGraph#nextNeighbor()}
* @throws IOException when advance neighbors
*/
int graphNextNeighbor(HnswGraph graph) throws IOException {
return graph.nextNeighbor();
}

/**
* This class allow {@link OnHeapHnswGraph} to be searched in a thread-safe manner.
*
* <p>Note the class itself is NOT thread safe, but since each search will create one new graph
* searcher the search method is thread safe.
*/
private static class OnHeapHnswGraphSearcher<C> extends HnswGraphSearcher<C> {

private NeighborArray cur;
private int upto;

private OnHeapHnswGraphSearcher(
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
NeighborQueue candidates,
BitSet visited) {
super(vectorEncoding, similarityFunction, candidates, visited);
}

@Override
void graphSeek(HnswGraph graph, int level, int targetNode) {
cur = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode);
upto = -1;
}

@Override
int graphNextNeighbor(HnswGraph graph) {
if (++upto < cur.size()) {
return cur.node[upto];
}
return NO_MORE_DOCS;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95Codec;
Expand Down Expand Up @@ -66,6 +72,7 @@
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
Expand Down Expand Up @@ -990,6 +997,105 @@ public void testRandom() throws IOException {
assertTrue("overlap=" + overlap, overlap > 0.9);
}

/* test thread-safety of searching OnHeapHnswGraph */
@SuppressWarnings("unchecked")
public void testOnHeapHnswGraphSearch()
throws IOException, ExecutionException, InterruptedException, TimeoutException {
int size = atLeast(100);
int dim = atLeast(10);
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
int topK = 5;
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);

List<T> queries = new ArrayList<>();
List<NeighborQueue> expects = new ArrayList<>();
for (int i = 0; i < 100; i++) {
NeighborQueue expect;
T query = randomVector(dim);
queries.add(query);
expect =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(byte[]) query,
100,
(RandomAccessVectorValues<byte[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) query,
100,
(RandomAccessVectorValues<float[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};

while (expect.size() > topK) {
expect.pop();
}
expects.add(expect);
}

ExecutorService exec =
Executors.newFixedThreadPool(4, new NamedThreadFactory("onHeapHnswSearch"));
List<Future<NeighborQueue>> futures = new ArrayList<>();
for (T query : queries) {
futures.add(
exec.submit(
() -> {
NeighborQueue actual;
try {
actual =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(byte[]) query,
100,
(RandomAccessVectorValues<byte[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) query,
100,
(RandomAccessVectorValues<float[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
while (actual.size() > topK) {
actual.pop();
}
return actual;
}));
}
List<NeighborQueue> actuals = new ArrayList<>();
for (Future<NeighborQueue> future : futures) {
actuals.add(future.get(10, TimeUnit.SECONDS));
}
exec.shutdownNow();
for (int i = 0; i < expects.size(); i++) {
NeighborQueue expect = expects.get(i);
NeighborQueue actual = actuals.get(i);
assertArrayEquals(expect.nodes(), actual.nodes());
}
}

private int computeOverlap(int[] a, int[] b) {
Arrays.sort(a);
Arrays.sort(b);
Expand Down

0 comments on commit 25a908d

Please sign in to comment.