diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java index bec0541a76ff..28dc5175bddc 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java @@ -136,25 +136,29 @@ public void testAknnDiverse() throws IOException { HnswGraph.search( new float[] {1, 0}, 10, - 5, + 10, vectors.randomAccess(), VectorSimilarityFunction.DOT_PRODUCT, hnsw, null, random()); + + int[] nodes = nn.nodes(); + assertTrue("Number of found results is not equal to [10].", nodes.length == 10); int sum = 0; - for (int node : nn.nodes()) { + for (int node : nodes) { sum += node; } - // We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = - // 45 + // We expect to get approximately 100% recall; + // the lowest docIds are closest to zero; sum(0,9) = 45 assertTrue("sum(result docs)=" + sum, sum < 75); + for (int i = 0; i < nDoc; i++) { NeighborArray neighbors = hnsw.getNeighbors(i); - int[] nodes = neighbors.node; + int[] nnodes = neighbors.node; for (int j = 0; j < neighbors.size(); j++) { // all neighbors should be valid node ids. 
- assertTrue(nodes[j] < nDoc); + assertTrue(nnodes[j] < nDoc); } } } @@ -167,24 +171,27 @@ public void testSearchWithAcceptOrds() throws IOException { vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt()); HnswGraph hnsw = builder.build(vectors); - Bits acceptOrds = createRandomAcceptOrds(vectors.size); + // the first 10 docs must not be deleted to ensure the expected recall + Bits acceptOrds = createRandomAcceptOrds(10, vectors.size); NeighborQueue nn = HnswGraph.search( new float[] {1, 0}, 10, - 5, + 10, vectors.randomAccess(), VectorSimilarityFunction.DOT_PRODUCT, hnsw, acceptOrds, random()); + int[] nodes = nn.nodes(); + assertTrue("Number of found results is not equal to [10].", nodes.length == 10); int sum = 0; - for (int node : nn.nodes()) { + for (int node : nodes) { assertTrue("the results include a deleted document: " + node, acceptOrds.get(node)); sum += node; } - // We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = - // 45 + // We expect to get approximately 100% recall; + // the lowest docIds are closest to zero; sum(0,9) = 45 assertTrue("sum(result docs)=" + sum, sum < 75); } @@ -311,7 +318,7 @@ public void testRandom() throws IOException { HnswGraphBuilder builder = new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong()); HnswGraph hnsw = builder.build(vectors); - Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(size); + Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size); int totalMatches = 0; for (int i = 0; i < 100; i++) { @@ -492,10 +499,18 @@ private static float[][] createRandomVectors(int size, int dimension, Random ran } } - /** Generate a random bitset where each entry has a 2/3 probability of being set. 
*/ - private static Bits createRandomAcceptOrds(int length) { + /** + * Generate a random bitset where all bits before startIndex are set, and each entry from + * startIndex onward has a 2/3 probability of being set. + */ + private static Bits createRandomAcceptOrds(int startIndex, int length) { FixedBitSet bits = new FixedBitSet(length); - for (int i = 0; i < bits.length(); i++) { + // all bits are set before startIndex + for (int i = 0; i < startIndex; i++) { + bits.set(i); + } + // from startIndex onward, each bit is set with 2/3 probability + for (int i = startIndex; i < bits.length(); i++) { if (random().nextFloat() < 0.667f) { bits.set(i); }