Skip to content

Commit

Permalink
Use method overloading
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaih committed May 9, 2023
1 parent 5520eea commit 2650d3a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 150 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ public HnswGraphSearcher(
/**
* Searches HNSW graph for the nearest neighbors of a query vector.
*
* <p>Note: if you want to search {@link OnHeapHnswGraph} in a thread-safety manner, please
* consider using {@link OnHeapHnswGraphSearcher}
*
* @param query search query vector
* @param topK the number of nodes to be returned
* @param vectors the vector values
Expand Down Expand Up @@ -106,12 +103,33 @@ public static NeighborQueue search(
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

/**
* Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to
* {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding,
* VectorSimilarityFunction, HnswGraph, Bits, int)}
*/
public static NeighborQueue search(
float[] query,
int topK,
RandomAccessVectorValues<float[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
OnHeapHnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
OnHeapHnswGraphSearcher<float[]> graphSearcher =
new OnHeapHnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

/**
* Searches HNSW graph for the nearest neighbors of a query vector.
*
* <p>Note: if you want to search {@link OnHeapHnswGraph} in a thread-safety manner, please
* consider using {@link OnHeapHnswGraphSearcher}
*
* @param query search query vector
* @param topK the number of nodes to be returned
* @param vectors the vector values
Expand Down Expand Up @@ -149,6 +167,30 @@ public static NeighborQueue search(
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

/**
* Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to
* {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction,
* HnswGraph, Bits, int)}
*/
public static NeighborQueue search(
byte[] query,
int topK,
RandomAccessVectorValues<byte[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
OnHeapHnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
OnHeapHnswGraphSearcher<byte[]> graphSearcher =
new OnHeapHnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

static <T> NeighborQueue search(
T query,
int topK,
Expand Down Expand Up @@ -321,4 +363,38 @@ void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException {
int graphNextNeighbor(HnswGraph graph) throws IOException {
return graph.nextNeighbor();
}

/**
* This class allow {@link OnHeapHnswGraph} to be searched in a thread-safe manner.
*
* <p>Note the class itself is NOT thread safe, but since each search will create one new graph
* searcher the search method is thread safe.
*/
private static class OnHeapHnswGraphSearcher<C> extends HnswGraphSearcher<C> {

private NeighborArray cur;
private int upto;

private OnHeapHnswGraphSearcher(
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
NeighborQueue candidates,
BitSet visited) {
super(vectorEncoding, similarityFunction, candidates, visited);
}

@Override
void graphSeek(HnswGraph graph, int level, int targetNode) {
cur = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode);
upto = -1;
}

@Override
int graphNextNeighbor(HnswGraph graph) {
if (++upto < cur.size()) {
return cur.node[upto];
}
return NO_MORE_DOCS;
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ public void testRandom() throws IOException {
assertTrue("overlap=" + overlap, overlap > 0.9);
}

/* test using multi-thread to search OnHeapHnswGraph with OnHeapHnswGraphSearcher */
/* test thread-safety of searching OnHeapHnswGraph */
@SuppressWarnings("unchecked")
public void testOnHeapHnswGraphSearch()
throws IOException, ExecutionException, InterruptedException, TimeoutException {
Expand Down Expand Up @@ -1021,7 +1021,7 @@ public void testOnHeapHnswGraphSearch()
try {
actual =
switch (getVectorEncoding()) {
case BYTE -> OnHeapHnswGraphSearcher.search(
case BYTE -> HnswGraphSearcher.search(
(byte[]) query,
100,
(RandomAccessVectorValues<byte[]>) vectors,
Expand All @@ -1030,7 +1030,7 @@ public void testOnHeapHnswGraphSearch()
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> OnHeapHnswGraphSearcher.search(
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) query,
100,
(RandomAccessVectorValues<float[]>) vectors,
Expand Down

0 comments on commit 2650d3a

Please sign in to comment.