diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index 68c8967b9b28..9281b6374411 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -25,6 +25,7 @@ import java.util.Arrays; import org.apache.lucene.codecs.BufferingKnnVectorsWriter; import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -36,7 +37,6 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** @@ -227,11 +227,10 @@ private void writeMeta( } else { meta.writeInt(graph.numLevels()); for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - meta.writeInt(nodesOnLevel.size()); // number of nodes on a level + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + meta.writeInt(sortedNodes.length); // number of nodes on a level if (level > 0) { - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + for (int node : sortedNodes) { meta.writeInt(node); // list of nodes on a level } } @@ -257,9 +256,8 @@ private Lucene91OnHeapHnswGraph writeGraph( // write vectors' neighbours on each level into the vectorIndex file int countOnLevel0 = graph.size(); for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + for (int node : sortedNodes) { Lucene91NeighborArray neighbors = graph.getNeighbors(level, node); int size = neighbors.size(); vectorIndex.writeInt(size); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index 1480d1aea2e4..b2e7629aed1c 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -27,6 +27,7 @@ import org.apache.lucene.codecs.BufferingKnnVectorsWriter; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -39,7 +40,6 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; @@ -261,11 +261,10 @@ private void writeMeta( } else { meta.writeInt(graph.numLevels()); for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - meta.writeInt(nodesOnLevel.size()); // number of nodes on a level + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + meta.writeInt(sortedNodes.length); // number of nodes on a level if (level > 0) { - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + for (int node : sortedNodes) { meta.writeInt(node); // list of nodes on a level } } @@ -293,9 +292,8 @@ private OnHeapHnswGraph writeGraph( int countOnLevel0 = graph.size(); for (int level = 0; level < graph.numLevels(); level++) { int maxConnOnLevel = level == 0 ? (M * 2) : M; - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + for (int node : sortedNodes) { NeighborArray neighbors = graph.getNeighbors(level, node); int size = neighbors.size(); vectorIndex.writeInt(size); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index f6f378027603..9a2a156f98ac 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -30,6 +30,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -303,9 +304,8 @@ private HnswGraph reconstructAndWriteGraph( for (int level = 1; level < graph.numLevels(); level++) { NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); int[] newNodes = new int[nodesOnLevel.size()]; - int n = 0; - while (nodesOnLevel.hasNext()) { - newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()]; + for (int n = 0; nodesOnLevel.hasNext(); n++) { + newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()]; } Arrays.sort(newNodes); nodesByLevel.add(newNodes); @@ -481,9 +481,8 @@ private void writeGraph(OnHeapHnswGraph graph) throws IOException { int countOnLevel0 = graph.size(); for (int level = 0; level < graph.numLevels(); level++) { int maxConnOnLevel = level == 0 ? (M * 2) : M; - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + for (int node : sortedNodes) { NeighborArray neighbors = graph.getNeighbors(level, node); int size = neighbors.size(); vectorIndex.writeInt(size); @@ -570,11 +569,10 @@ private void writeMeta( } else { meta.writeInt(graph.numLevels()); for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - meta.writeInt(nodesOnLevel.size()); // number of nodes on a level + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + meta.writeInt(sortedNodes.length); // number of nodes on a level if (level > 0) { - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + for (int node : sortedNodes) { meta.writeInt(node); // list of nodes on a level } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java index bf0b79807f06..5358d66f16e2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -315,9 +315,8 @@ private HnswGraph reconstructAndWriteGraph( for (int level = 1; level < graph.numLevels(); level++) { NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); int[] newNodes = new int[nodesOnLevel.size()]; - int n = 0; - while (nodesOnLevel.hasNext()) { - newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()]; + for (int n = 0; nodesOnLevel.hasNext(); n++) { + newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()]; } Arrays.sort(newNodes); nodesByLevel.add(newNodes); @@ -677,11 +676,10 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException { int countOnLevel0 = graph.size(); int[][] offsets = new int[graph.numLevels()][]; for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - offsets[level] = new int[nodesOnLevel.size()]; + int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level)); + offsets[level] = new int[sortedNodes.length]; int nodeOffsetId = 0; - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + for (int node : sortedNodes) { NeighborArray neighbors = graph.getNeighbors(level, node); int size = neighbors.size(); // Write size in VInt as the neighbors list is typically small @@ -706,6 +704,15 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException { return offsets; } + public static int[] getSortedNodes(NodesIterator nodesOnLevel) { + int[] sortedNodes = new int[nodesOnLevel.size()]; + for (int n = 0; nodesOnLevel.hasNext(); n++) { + sortedNodes[n] = nodesOnLevel.nextInt(); + } + Arrays.sort(sortedNodes); + return sortedNodes; + } + private void writeMeta( FieldInfo field, int maxDoc, @@ -779,6 +786,7 @@ private void writeMeta( if (level > 0) { int[] nol = new int[nodesOnLevel.size()]; int numberConsumed = nodesOnLevel.consume(nol); + Arrays.sort(nol); assert numberConsumed == nodesOnLevel.size(); meta.writeVInt(nol.length); // number of nodes on a level for (int i = nodesOnLevel.size() - 1; i > 0; --i) { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java index 9086ab55d2eb..9b3d0d62c905 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -81,7 +81,8 @@ protected HnswGraph() {} public abstract int entryNode() throws IOException; /** - * Get all nodes on a given level as node 0th ordinals + * Get all nodes on a given level as node 0th ordinals. The nodes are NOT guaranteed to be + * presented in any particular order. * * @param level level for which to get all nodes * @return an iterator over nodes where {@code nextInt} returns a next node on the level @@ -123,7 +124,8 @@ public NodesIterator getNodesOnLevel(int level) { /** * Iterator over the graph nodes on a certain level, Iterator also provides the size – the total - * number of nodes to be iterated over. + * number of nodes to be iterated over. The nodes are NOT guaranteed to be presented in any + * particular order. */ public abstract static class NodesIterator implements PrimitiveIterator.OfInt { protected final int size; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index 9862536de08c..ae39614f160e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -20,8 +20,9 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; -import java.util.TreeMap; +import java.util.Map; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; @@ -40,12 +41,12 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { // added to HnswBuilder, and the node values are the ordinals of those vectors. // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. private final List graphLevel0; - // Represents levels 1-N. Each level is represented with a TreeMap that maps a levels level 0 + // Represents levels 1-N. Each level is represented with a Map that maps a levels level 0 // ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain // it in this list. However, to avoid changing list indexing, we always will make the first // element // null. - private final List> graphUpperLevels; + private final List> graphUpperLevels; private final int nsize; private final int nsize0; @@ -76,7 +77,7 @@ public NeighborArray getNeighbors(int level, int node) { if (level == 0) { return graphLevel0.get(node); } - TreeMap levelMap = graphUpperLevels.get(level); + Map levelMap = graphUpperLevels.get(level); assert levelMap.containsKey(node); return levelMap.get(node); } @@ -103,7 +104,7 @@ public void addNode(int level, int node) { // and make this node the graph's new entry point if (level >= numLevels) { for (int i = numLevels; i <= level; i++) { - graphUpperLevels.add(new TreeMap<>()); + graphUpperLevels.add(new HashMap<>()); } numLevels = level + 1; entryNode = node; @@ -204,4 +205,15 @@ public long ramBytesUsed() { } return total; } + + @Override + public String toString() { + return "OnHeapHnswGraph(size=" + + size() + + ", numLevels=" + + numLevels + + ", entryNode=" + + entryNode + + ")"; + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 80c9c7a93cf4..9825d4a5f419 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -29,6 +29,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Random; import java.util.Set; @@ -265,19 +266,50 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } } + List sortedNodesOnLevel(HnswGraph h, int level) throws IOException { + NodesIterator nodesOnLevel = h.getNodesOnLevel(level); + List nodes = new ArrayList<>(); + while (nodesOnLevel.hasNext()) { + nodes.add(nodesOnLevel.next()); + } + Collections.sort(nodes); + return nodes; + } + void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException { - assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels()); - assertEquals("the number of nodes in the graphs are different!", g.size(), h.size()); + // construct these up front since they call seek which will mess up our test loop + String prettyG = prettyPrint(g); + String prettyH = prettyPrint(h); + assertEquals( + String.format( + Locale.ROOT, + "the number of levels in the graphs are different:%n%s%n%s", + prettyG, + prettyH), + g.numLevels(), + h.numLevels()); + assertEquals( + String.format( + Locale.ROOT, + "the number of nodes in the graphs are different:%n%s%n%s", + prettyG, + prettyH), + g.size(), + h.size()); // assert equal nodes on each level for (int level = 0; level < g.numLevels(); level++) { - NodesIterator nodesOnLevel = g.getNodesOnLevel(level); - NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level); - while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) { - int node = nodesOnLevel.nextInt(); - int node2 = nodesOnLevel2.nextInt(); - assertEquals("nodes in the graphs are different", node, node2); - } + List hNodes = sortedNodesOnLevel(h, level); + List gNodes = sortedNodesOnLevel(g, level); + assertEquals( + String.format( + Locale.ROOT, + "nodes in the graphs are different on level %d:%n%s%n%s", + level, + prettyG, + prettyH), + gNodes, + hNodes); } // assert equal nodes' neighbours on each level @@ -287,7 +319,16 @@ void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException { int node = nodesOnLevel.nextInt(); g.seek(level, node); h.seek(level, node); - assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h)); + assertEquals( + String.format( + Locale.ROOT, + "arcs differ for node %d on level %d:%n%s%n%s", + node, + level, + prettyG, + prettyH), + getNeighborNodes(g), + getNeighborNodes(h)); } } } @@ -495,14 +536,12 @@ public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException { } for (int currLevel = 1; currLevel < numLevels; currLevel++) { - NodesIterator nodesIterator = bottomUpExpectedHnsw.getNodesOnLevel(currLevel); List expectedNodesOnLevel = nodesPerLevel.get(currLevel); - assertEquals(expectedNodesOnLevel.size(), nodesIterator.size()); - for (Integer expectedNode : expectedNodesOnLevel) { - int currentNode = nodesIterator.nextInt(); - assertEquals(expectedNode.intValue(), currentNode); - assertEquals(0, bottomUpExpectedHnsw.getNeighbors(currLevel, currentNode).size()); - } + List sortedNodes = sortedNodesOnLevel(bottomUpExpectedHnsw, currLevel); + assertEquals( + String.format(Locale.ROOT, "Nodes on level %d do not match", currLevel), + expectedNodesOnLevel, + sortedNodes); } assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw); @@ -607,13 +646,10 @@ private void assertGraphInitializedFromGraph( // assert the nodes from the previous graph are successfully to levels > 0 in the new graph for (int level = 1; level < g.numLevels(); level++) { - NodesIterator nodesOnLevel = g.getNodesOnLevel(level); - NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level); - while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) { - int node = nodesOnLevel.nextInt(); - int node2 = oldToNewOrdMap.get(nodesOnLevel2.nextInt()); - assertEquals("nodes in the graphs are different", node, node2); - } + List nodesOnLevel = sortedNodesOnLevel(g, level); + List nodesOnLevel2 = + sortedNodesOnLevel(h, level).stream().map(oldToNewOrdMap::get).toList(); + assertEquals(nodesOnLevel, nodesOnLevel2); } // assert that the neighbors from the old graph are successfully transferred to the new graph @@ -1196,4 +1232,34 @@ static byte[] randomVector8(Random random, int dim) { } return bvec; } + + static String prettyPrint(HnswGraph hnsw) { + StringBuilder sb = new StringBuilder(); + sb.append(hnsw); + sb.append("\n"); + + try { + for (int level = 0; level < hnsw.numLevels(); level++) { + sb.append("# Level ").append(level).append("\n"); + NodesIterator it = hnsw.getNodesOnLevel(level); + while (it.hasNext()) { + int node = it.nextInt(); + sb.append(" ").append(node).append(" -> "); + hnsw.seek(level, node); + while (true) { + int neighbor = hnsw.nextNeighbor(); + if (neighbor == NO_MORE_DOCS) { + break; + } + sb.append(" ").append(neighbor); + } + sb.append("\n"); + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + + return sb.toString(); + } }