Skip to content

Commit

Permalink
LUCENE-10040 Correct TestHnswGraph.testSearchWithAcceptOrds (#277)
Browse files Browse the repository at this point in the history
If we set numSeed = 10, this test fails sometimes  because it may mark
expected results docs (from 0 to 9) as deleted which don't end up
being retrieved, resulting in a low recall

- set numSeed to 10 to ensure 10 results are returned
- add startIndex paramenter to createRandomAcceptOrds that allows
  documents before startIndex to be NOT deleted
- use startIndex equal to 10 for createRandomAcceptOrds

Relates to #239
  • Loading branch information
mayya-sharipova authored Sep 6, 2021
1 parent 4df8d64 commit bc161e6
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,25 +136,29 @@ public void testAknnDiverse() throws IOException {
HnswGraph.search(
new float[] {1, 0},
10,
5,
10,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
null,
random());

int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
int sum = 0;
for (int node : nn.nodes()) {
for (int node : nodes) {
sum += node;
}
// We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) =
// 45
// We expect to get approximately 100% recall;
// the lowest docIds are closest to zero; sum(0,9) = 45
assertTrue("sum(result docs)=" + sum, sum < 75);

for (int i = 0; i < nDoc; i++) {
NeighborArray neighbors = hnsw.getNeighbors(i);
int[] nodes = neighbors.node;
int[] nnodes = neighbors.node;
for (int j = 0; j < neighbors.size(); j++) {
// all neighbors should be valid node ids.
assertTrue(nodes[j] < nDoc);
assertTrue(nnodes[j] < nDoc);
}
}
}
Expand All @@ -167,24 +171,27 @@ public void testSearchWithAcceptOrds() throws IOException {
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
HnswGraph hnsw = builder.build(vectors);

Bits acceptOrds = createRandomAcceptOrds(vectors.size);
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
NeighborQueue nn =
HnswGraph.search(
new float[] {1, 0},
10,
5,
10,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
acceptOrds,
random());
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
int sum = 0;
for (int node : nn.nodes()) {
for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
sum += node;
}
// We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) =
// 45
// We expect to get approximately 100% recall;
// the lowest docIds are closest to zero; sum(0,9) = 45
assertTrue("sum(result docs)=" + sum, sum < 75);
}

Expand Down Expand Up @@ -311,7 +318,7 @@ public void testRandom() throws IOException {
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
HnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(size);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);

int totalMatches = 0;
for (int i = 0; i < 100; i++) {
Expand Down Expand Up @@ -492,10 +499,18 @@ private static float[][] createRandomVectors(int size, int dimension, Random ran
}
}

/** Generate a random bitset where each entry has a 2/3 probability of being set. */
private static Bits createRandomAcceptOrds(int length) {
/**
* Generate a random bitset where before startIndex all bits are set, and after startIndex each
* entry has a 2/3 probability of being set.
*/
private static Bits createRandomAcceptOrds(int startIndex, int length) {
FixedBitSet bits = new FixedBitSet(length);
for (int i = 0; i < bits.length(); i++) {
// all bits are set before startIndex
for (int i = 0; i < startIndex; i++) {
bits.set(i);
}
// after startIndex, bits are set with 2/3 probability
for (int i = startIndex; i < bits.length(); i++) {
if (random().nextFloat() < 0.667f) {
bits.set(i);
}
Expand Down

0 comments on commit bc161e6

Please sign in to comment.