From 98fbbe2d673a22d3d46d3a0d9c0f1f65d7a8ed5c Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Thu, 8 Aug 2024 14:41:52 -0400 Subject: [PATCH] gh-12627: HnswGraphBuilder connects disconnected HNSW graph components (#13566) --- lucene/CHANGES.txt | 2 + .../lucene99/Lucene99HnswVectorsReader.java | 1 + .../lucene99/Lucene99HnswVectorsWriter.java | 2 +- .../apache/lucene/util/hnsw/HnswBuilder.java | 8 + .../util/hnsw/HnswConcurrentMergeBuilder.java | 19 +- .../lucene/util/hnsw/HnswGraphBuilder.java | 103 +++++- .../org/apache/lucene/util/hnsw/HnswUtil.java | 268 +++++++++++++++ .../lucene/util/hnsw/OnHeapHnswGraph.java | 5 +- .../BaseVectorSimilarityQueryTestCase.java | 8 +- .../apache/lucene/util/hnsw/TestHnswUtil.java | 315 ++++++++++++++++++ .../test-framework/src/java/module-info.java | 1 - .../lucene/tests/util/hnsw/HnswTestUtil.java | 132 -------- 12 files changed, 723 insertions(+), 141 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java create mode 100644 lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java delete mode 100644 lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index b7f12719f310..1de2cf7ffb83 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -45,6 +45,8 @@ Improvements * GITHUB#13633: Add ability to read/write knn vector values to a MemoryIndex. (Ben Trent) +* GITHUB#12627: patch HNSW graphs to improve reachability of all nodes from entry points + Optimizations --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index e388fd110a15..152d409ce7d2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -460,6 +460,7 @@ public void seek(int level, int targetOrd) throws IOException { // unsafe; no bounds checking dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level])); arcCount = dataIn.readVInt(); + assert arcCount <= currentNeighborsBuffer.length : "too many neighbors: " + arcCount; if (arcCount > 0) { currentNeighborsBuffer[0] = dataIn.readVInt(); for (int i = 1; i < arcCount; i++) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index e2de93ffc251..449ae5e80187 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -633,7 +633,7 @@ public T copyValue(T vectorValue) { throw new UnsupportedOperationException(); } - OnHeapHnswGraph getGraph() { + OnHeapHnswGraph getGraph() throws IOException { assert flatFieldVectorsWriter.isFinished(); if (node > 0) { return hnswGraphBuilder.getGraph(); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswBuilder.java index 547385607af5..aa27525b7f10 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswBuilder.java @@ -41,4 +41,12 @@ public interface HnswBuilder { void setInfoStream(InfoStream infoStream); OnHeapHnswGraph getGraph(); + + /** + * Once this method is called no further updates to the graph are accepted (addGraphNode will + * throw IllegalStateException). Final modifications to the graph (eg patching up disconnected + * components, re-ordering node ids for better delta compression) may be triggered, so callers + * should expect this call to take some time. + */ + OnHeapHnswGraph getCompletedGraph() throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java index 7407f2c8f27d..fc37f9cb690e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java @@ -42,6 +42,7 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder { private final ConcurrentMergeWorker[] workers; private final HnswLock hnswLock; private InfoStream infoStream = InfoStream.getDefault(); + private boolean frozen; public HnswConcurrentMergeBuilder( TaskExecutor taskExecutor, @@ -87,7 +88,9 @@ public OnHeapHnswGraph build(int maxOrd) throws IOException { }); } taskExecutor.invokeAll(futures); - return workers[0].getGraph(); + finish(); + frozen = true; + return workers[0].getCompletedGraph(); } @Override @@ -103,6 +106,20 @@ public void setInfoStream(InfoStream infoStream) { } } + @Override + public OnHeapHnswGraph getCompletedGraph() throws IOException { + if (frozen == false) { + // should already have been called in build(), but just in case + finish(); + frozen = true; + } + return getGraph(); + } + + private void finish() throws IOException { + workers[0].finish(); + } + @Override public OnHeapHnswGraph getGraph() { return workers[0].getGraph(); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 8898d4f0a3ae..1d38f6b14204 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -18,8 +18,11 @@ package org.apache.lucene.util.hnsw; import static java.lang.Math.log; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import java.util.Comparator; +import java.util.List; import java.util.Locale; import java.util.Objects; import java.util.SplittableRandom; @@ -28,6 +31,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.hnsw.HnswUtil.Component; /** * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the @@ -66,6 +70,7 @@ public class HnswGraphBuilder implements HnswBuilder { protected final HnswLock hnswLock; private InfoStream infoStream = InfoStream.getDefault(); + private boolean frozen; public static HnswGraphBuilder create( RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) @@ -136,7 +141,7 @@ protected HnswGraphBuilder( HnswGraphSearcher graphSearcher) throws IOException { if (M <= 0) { - throw new IllegalArgumentException("maxConn must be positive"); + throw new IllegalArgumentException("M (max connections) must be positive"); } if (beamWidth <= 0) { throw new IllegalArgumentException("beamWidth must be positive"); @@ -168,6 +173,15 @@ public void setInfoStream(InfoStream infoStream) { this.infoStream = infoStream; } + @Override + public OnHeapHnswGraph getCompletedGraph() throws IOException { + if (!frozen) { + finish(); + frozen = true; + } + return getGraph(); + } + @Override public OnHeapHnswGraph getGraph() { return hnsw; @@ -389,6 +403,93 @@ private static int getRandomGraphLevel(double ml, SplittableRandom random) { return ((int) (-log(randDouble) * ml)); } + void finish() throws IOException { + connectComponents(); + } + + private void connectComponents() throws IOException { + long start = System.nanoTime(); + for (int level = 0; level < hnsw.numLevels(); level++) { + if (connectComponents(level) == false) { + if (infoStream.isEnabled(HNSW_COMPONENT)) { + infoStream.message(HNSW_COMPONENT, "connectComponents failed on level " + level); + } + } + } + if (infoStream.isEnabled(HNSW_COMPONENT)) { + infoStream.message( + HNSW_COMPONENT, "connectComponents " + (System.nanoTime() - start) / 1_000_000 + " ms"); + } + } + + private boolean connectComponents(int level) throws IOException { + FixedBitSet notFullyConnected = new FixedBitSet(hnsw.size()); + int maxConn = M; + if (level == 0) { + maxConn *= 2; + } + List components = HnswUtil.components(hnsw, level, notFullyConnected, maxConn); + boolean result = true; + if (components.size() > 1) { + // connect other components to the largest one + Component c0 = components.stream().max(Comparator.comparingInt(Component::size)).get(); + if (c0.start() == NO_MORE_DOCS) { + // the component is already fully connected - no room for new connections + return false; + } + // try for more connections? We only do one since otherwise they may become full + // while linking + GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(1); + int[] eps = new int[1]; + for (Component c : components) { + if (c != c0) { + beam.clear(); + eps[0] = c0.start(); + RandomVectorScorer scorer = scorerSupplier.scorer(c.start()); + // find the closest node in the largest component to the lowest-numbered node in this + // component that has room to make a connection + graphSearcher.searchLevel(beam, scorer, 0, eps, hnsw, notFullyConnected); + boolean linked = false; + while (beam.size() > 0) { + float score = beam.minimumScore(); + int c0node = beam.popNode(); + assert notFullyConnected.get(c0node); + // link the nodes + link(level, c0node, c.start(), score, notFullyConnected); + linked = true; + } + if (!linked) { + result = false; + } + } + } + } + return result; + } + + // Try to link two nodes bidirectionally; the forward connection will always be made. + // Update notFullyConnected. + private void link(int level, int n0, int n1, float score, FixedBitSet notFullyConnected) { + NeighborArray nbr0 = hnsw.getNeighbors(level, n0); + NeighborArray nbr1 = hnsw.getNeighbors(level, n1); + // must subtract 1 here since the nodes array is one larger than the configured + // max neighbors (M / 2M). + // We should have taken care of this check by searching for not-full nodes + int maxConn = nbr0.nodes().length - 1; + assert notFullyConnected.get(n0); + assert nbr0.size() < maxConn : "node " + n0 + " is full, has " + nbr0.size() + " friends"; + nbr0.addOutOfOrder(n1, score); + if (nbr0.size() == maxConn) { + notFullyConnected.clear(n0); + } + if (nbr1.size() < maxConn) { + nbr1.addOutOfOrder(n0, score); + if (nbr1.size() == maxConn) { + notFullyConnected.clear(n1); + } + } + } + /** * A restricted, specialized knnCollector that can be used when building a graph. * diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java new file mode 100644 index 000000000000..b34ead39f708 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util.hnsw; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.lucene.codecs.hnsw.HnswGraphProvider; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.FilterLeafReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.util.FixedBitSet; + +/** Utilities for use in tests involving HNSW graphs */ +public class HnswUtil { + + // utility class; only has static methods + private HnswUtil() {} + + /* + For each level, check rooted components from previous level nodes, which are entry + points with the goal that each node should be reachable from *some* entry point. For each entry + point, compute a spanning tree, recording the nodes in a single shared bitset. + + Also record a bitset marking nodes that are not full to be used when reconnecting in order to + limit the search to include non-full nodes only. + */ + + /** Returns true if every node on every level is reachable from node 0. */ + static boolean isRooted(HnswGraph knnValues) throws IOException { + for (int level = 0; level < knnValues.numLevels(); level++) { + if (components(knnValues, level, null, 0).size() > 1) { + return false; + } + } + return true; + } + + /** + * Returns the sizes of the distinct graph components on level 0. If the graph is fully-rooted the + * list will have one entry. If it is empty, the returned list will be empty. + */ + static List componentSizes(HnswGraph hnsw) throws IOException { + return componentSizes(hnsw, 0); + } + + /** + * Returns the sizes of the distinct graph components on the given level. The forest starting at + * the entry points (nodes in the next highest level) is considered as a single component. If the + * entire graph is rooted in the entry points, that is every node is reachable from at least one + * entry point, the returned list will have a single entry. If the graph is empty, the returned + * list will be empty. + */ + static List componentSizes(HnswGraph hnsw, int level) throws IOException { + return components(hnsw, level, null, 0).stream() + .map(Component::size) + .collect(Collectors.toList()); + } + + // Finds orphaned components on the graph level. + static List components( + HnswGraph hnsw, int level, FixedBitSet notFullyConnected, int maxConn) throws IOException { + List components = new ArrayList<>(); + FixedBitSet connectedNodes = new FixedBitSet(hnsw.size()); + assert hnsw.size() == hnsw.getNodesOnLevel(0).size(); + int total = 0; + if (level >= hnsw.numLevels()) { + throw new IllegalArgumentException( + "Level " + level + " too large for graph with " + hnsw.numLevels() + " levels"); + } + HnswGraph.NodesIterator entryPoints; + // System.out.println("components level=" + level); + if (level == hnsw.numLevels() - 1) { + entryPoints = new HnswGraph.ArrayNodesIterator(new int[] {hnsw.entryNode()}, 1); + } else { + entryPoints = hnsw.getNodesOnLevel(level + 1); + } + while (entryPoints.hasNext()) { + int entryPoint = entryPoints.nextInt(); + Component component = + markRooted(hnsw, level, connectedNodes, notFullyConnected, maxConn, entryPoint); + total += component.size(); + } + int entryPoint; + if (notFullyConnected != null) { + entryPoint = notFullyConnected.nextSetBit(0); + } else { + entryPoint = connectedNodes.nextSetBit(0); + } + components.add(new Component(entryPoint, total)); + if (level == 0) { + int nextClear = nextClearBit(connectedNodes, 0); + while (nextClear != NO_MORE_DOCS) { + Component component = + markRooted(hnsw, level, connectedNodes, notFullyConnected, maxConn, nextClear); + assert component.size() > 0; + components.add(component); + total += component.size(); + nextClear = nextClearBit(connectedNodes, component.start()); + } + } else { + HnswGraph.NodesIterator nodes = hnsw.getNodesOnLevel(level); + while (nodes.hasNext()) { + int nextClear = nodes.nextInt(); + if (connectedNodes.get(nextClear)) { + continue; + } + Component component = + markRooted(hnsw, level, connectedNodes, notFullyConnected, maxConn, nextClear); + assert component.size() > 0; + components.add(component); + total += component.size(); + } + } + assert total == hnsw.getNodesOnLevel(level).size() + : "total=" + + total + + " level nodes on level " + + level + + " = " + + hnsw.getNodesOnLevel(level).size(); + return components; + } + + /** + * Count the nodes in a rooted component of the graph and set the bits of its nodes in + * connectedNodes bitset. Rooted means nodes that can be reached from a root node. + * + * @param hnswGraph the graph to check + * @param level the level of the graph to check + * @param connectedNodes a bitset the size of the entire graph with 1's indicating nodes that have + * been marked as connected. This method updates the bitset. + * @param notFullyConnected a bitset the size of the entire graph. On output, we mark nodes + * visited having fewer than maxConn connections. May be null. + * @param maxConn the maximum number of connections for any node (aka M). + * @param entryPoint a node id to start at + */ + private static Component markRooted( + HnswGraph hnswGraph, + int level, + FixedBitSet connectedNodes, + FixedBitSet notFullyConnected, + int maxConn, + int entryPoint) + throws IOException { + // Start at entry point and search all nodes on this level + // System.out.println("markRooted level=" + level + " entryPoint=" + entryPoint); + Deque stack = new ArrayDeque<>(); + stack.push(entryPoint); + int count = 0; + while (!stack.isEmpty()) { + int node = stack.pop(); + if (connectedNodes.get(node)) { + continue; + } + count++; + connectedNodes.set(node); + hnswGraph.seek(level, node); + int friendOrd; + int friendCount = 0; + while ((friendOrd = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) { + ++friendCount; + stack.push(friendOrd); + } + if (friendCount < maxConn && notFullyConnected != null) { + notFullyConnected.set(node); + } + } + return new Component(entryPoint, count); + } + + private static int nextClearBit(FixedBitSet bits, int index) { + // Does not depend on the ghost bits being clear! + long[] barray = bits.getBits(); + assert index >= 0 && index < bits.length() : "index=" + index + ", numBits=" + bits.length(); + int i = index >> 6; + long word = ~(barray[i] >> index); // skip all the bits to the right of index + + int next = NO_MORE_DOCS; + if (word != 0) { + next = index + Long.numberOfTrailingZeros(word); + } else { + while (++i < barray.length) { + word = ~barray[i]; + if (word != 0) { + next = (i << 6) + Long.numberOfTrailingZeros(word); + break; + } + } + } + if (next >= bits.length()) { + return NO_MORE_DOCS; + } else { + return next; + } + } + + /** + * In graph theory, "connected components" are really defined only for undirected (ie + * bidirectional) graphs. Our graphs are directed, because of pruning, but they are *mostly* + * undirected. In this case we compute components starting from a single node so what we are + * really measuring is whether the graph is a "rooted graph". TODO: measure whether the graph is + * "strongly connected" ie there is a path from every node to every other node. + */ + public static boolean graphIsRooted(IndexReader reader, String vectorField) throws IOException { + for (LeafReaderContext ctx : reader.leaves()) { + CodecReader codecReader = (CodecReader) FilterLeafReader.unwrap(ctx.reader()); + HnswGraph graph = + ((HnswGraphProvider) + ((PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader()) + .getFieldReader(vectorField)) + .getGraph(vectorField); + if (isRooted(graph) == false) { + return false; + } + } + return true; + } + + /** + * A component (also "connected component") of an undirected graph is a collection of nodes that + * are connected by neighbor links: every node in a connected component is reachable from every + * other node in the component. See https://en.wikipedia.org/wiki/Component_(graph_theory). Such a + * graph is said to be "fully connected" iff it has a single component, or it is empty. + */ + static final class Component { + final int start; + final int size; + + /** + * @param start the lowest-numbered node in the component + * @param size the number of nodes in the component + */ + Component(int start, int size) { + this.start = start; + this.size = size; + } + + int start() { + return start; + } + + int size() { + return size; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index cac58f39638b..4758e6464e9f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -90,7 +90,10 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { * @param node the node whose neighbors are returned, represented as an ordinal on the level 0. */ public NeighborArray getNeighbors(int level, int node) { - assert graph[node][level] != null; + assert node < graph.length; + assert level < graph[node].length + : "level=" + level + ", node has only " + graph[node].length + " levels"; + assert graph[node][level] != null : "node=" + node + ", level=" + level; return graph[node][level]; } diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java index 5c324ee03852..0395dae807e5 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java @@ -37,7 +37,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.tests.util.hnsw.HnswTestUtil; +import org.apache.lucene.util.hnsw.HnswUtil; @LuceneTestCase.SuppressCodecs("SimpleText") abstract class BaseVectorSimilarityQueryTestCase< @@ -131,7 +131,7 @@ public void testExtremes() throws IOException { try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); IndexReader reader = DirectoryReader.open(indexStore)) { IndexSearcher searcher = newSearcher(reader); - assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField)); + assumeTrue("graph is disconnected", HnswUtil.graphIsRooted(reader, vectorField)); // All vectors are above -Infinity Query query1 = @@ -167,7 +167,7 @@ public void testRandomFilter() throws IOException { try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); IndexReader reader = DirectoryReader.open(indexStore)) { - assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField)); + assumeTrue("graph is disconnected", HnswUtil.graphIsRooted(reader, vectorField)); IndexSearcher searcher = newSearcher(reader); Query query = @@ -292,7 +292,7 @@ public void testSomeDeletes() throws IOException { w.commit(); try (IndexReader reader = DirectoryReader.open(indexStore)) { - assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField)); + assumeTrue("graph is disconnected", HnswUtil.graphIsRooted(reader, vectorField)); IndexSearcher searcher = newSearcher(reader); Query query = diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java new file mode 100644 index 000000000000..b001831f65c1 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java @@ -0,0 +1,315 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.List; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.FixedBitSet; + +public class TestHnswUtil extends LuceneTestCase { + + public void testTreeWithCycle() throws Exception { + // test a graph that is a tree - this is rooted from its root node, not rooted + // from any other node, and not strongly connected + int[][][] nodes = { + { + {1, 2}, // node 0 + {3, 4}, // node 1 + {5, 6}, // node 2 + {}, {}, {}, {0} + } + }; + HnswGraph graph = new MockGraph(nodes); + assertTrue(HnswUtil.isRooted(graph)); + assertEquals(List.of(7), HnswUtil.componentSizes(graph)); + } + + public void testBackLinking() throws Exception { + // test a graph that is a tree - this is rooted from its root node, not rooted + // from any other node, and not strongly connected + int[][][] nodes = { + { + {1, 2}, // node 0 + {3, 4}, // node 1 + {0}, // node 2 + {1}, {1}, {1}, {1} + } + }; + HnswGraph graph = new MockGraph(nodes); + assertFalse(HnswUtil.isRooted(graph)); + // [ {0, 1, 2, 3, 4}, {5}, {6} + assertEquals(List.of(5, 1, 1), HnswUtil.componentSizes(graph)); + } + + public void testChain() throws Exception { + // test a graph that is a chain - this is rooted from every node, thus strongly connected + int[][][] nodes = {{{1}, {2}, {3}, {0}}}; + HnswGraph graph = new MockGraph(nodes); + assertTrue(HnswUtil.isRooted(graph)); + assertEquals(List.of(4), HnswUtil.componentSizes(graph)); + } + + public void testTwoChains() throws Exception { + // test a graph that is two chains + int[][][] nodes = {{{2}, {3}, {0}, {1}}}; + HnswGraph graph = new MockGraph(nodes); + assertFalse(HnswUtil.isRooted(graph)); + assertEquals(List.of(2, 2), HnswUtil.componentSizes(graph)); + } + + public void testLevels() throws Exception { + // test a graph that has three levels + int[][][] nodes = { + {{1, 2}, {3}, {0}, {0}}, + {{2}, null, {0}, null}, + {{}, null, null, null} + }; + HnswGraph graph = new MockGraph(nodes); + // System.out.println(graph.toString()); + assertTrue(HnswUtil.isRooted(graph)); + assertEquals(List.of(4), HnswUtil.componentSizes(graph)); + } + + public void testLevelsNotRooted() throws Exception { + // test a graph that has two levels with an orphaned node + int[][][] nodes = { + {{1}, {0}, {0}}, + {{}, null, null} + }; + HnswGraph graph = new MockGraph(nodes); + assertFalse(HnswUtil.isRooted(graph)); + assertEquals(List.of(2, 1), HnswUtil.componentSizes(graph)); + } + + public void testRandom() throws Exception { + for (int i = 0; i < atLeast(10); i++) { + // test on a random directed graph comparing against a brute force algorithm + int numNodes = random().nextInt(99) + 1; + int numLevels = (int) Math.ceil(Math.log(numNodes)); + int[][][] nodes = new int[numLevels][][]; + for (int level = numLevels - 1; level >= 0; level--) { + nodes[level] = new int[numNodes][]; + for (int node = 0; node < numNodes; node++) { + if (level > 0) { + if ((level == numLevels - 1 && node > 0) + || (level < numLevels - 1 && nodes[level + 1][node] == null)) { + if (random().nextFloat() > Math.pow(Math.E, -level)) { + // skip some nodes, more on higher levels while ensuring every node present on a + // given level is present on all lower levels. Also ensure node 0 is always present. + continue; + } + } + } + int numNbrs = random().nextInt((numNodes + 7) / 8); + if (level == 0) { + numNbrs *= 2; + } + nodes[level][node] = new int[numNbrs]; + for (int nbr = 0; nbr < numNbrs; nbr++) { + while (true) { + int randomNbr = random().nextInt(numNodes); + if (nodes[level][randomNbr] != null) { + // allow self-linking; this doesn't arise in HNSW but it's valid more generally + nodes[level][node][nbr] = randomNbr; + break; + } + // nbr not on this level, try again + } + } + } + } + MockGraph graph = new MockGraph(nodes); + /**/ + if (i == 2) { + System.out.println("iter " + i); + System.out.print(graph.toString()); + } + /**/ + assertEquals(isRooted(nodes), HnswUtil.isRooted(graph)); + } + } + + private boolean isRooted(int[][][] nodes) { + for (int level = nodes.length - 1; level >= 0; level--) { + if (isRooted(nodes, level) == false) { + return false; + } + } + return true; + } + + private boolean isRooted(int[][][] nodes, int level) { + // check that the graph is rooted in the union of the entry nodes' trees + // System.out.println("isRooted level=" + level); + int[][] entryPoints; + if (level == nodes.length - 1) { + // entry into the top level is from a single entry point, fixed at 0 + entryPoints = new int[][] {nodes[level][0]}; + } else { + entryPoints = nodes[level + 1]; + } + FixedBitSet connected = new FixedBitSet(nodes[level].length); + int count = 0; + for (int entryPoint = 0; entryPoint < entryPoints.length; entryPoint++) { + if (entryPoints[entryPoint] == null) { + // use nodes present on next higher level (or this level if top level) as entry points + continue; + } + // System.out.println(" isRooted level=" + level + " entryPoint=" + entryPoint); + ArrayDeque stack = new ArrayDeque<>(); + stack.push(entryPoint); + while (!stack.isEmpty()) { + int node = stack.pop(); + if (connected.get(node)) { + continue; + } + // System.out.println(" connected node=" + node); + connected.set(node); + count++; + for (int nbr : nodes[level][node]) { + stack.push(nbr); + } + } + } + return count == levelSize(nodes[level]); + } + + static int levelSize(int[][] nodes) { + int count = 0; + for (int[] node : nodes) { + if (node != null) { + ++count; + } + } + return count; + } + + /** Empty graph value */ + static class MockGraph extends HnswGraph { + + private final int[][][] nodes; + + private int currentLevel; + private int currentNode; + private int currentNeighbor; + + MockGraph(int[][][] nodes) { + this.nodes = nodes; + } + + @Override + public int nextNeighbor() { + if (currentNeighbor >= nodes[currentLevel][currentNode].length) { + return NO_MORE_DOCS; + } else { + return nodes[currentLevel][currentNode][currentNeighbor++]; + } + } + + @Override + public void seek(int level, int target) { + assert level >= 0 && level < nodes.length; + assert target >= 0 && target < nodes[level].length + : "target out of range: " + + target + + " for level " + + level + + "; should be less than " + + nodes[level].length; + assert nodes[level][target] != null : "target " + target + " not on level " + level; + currentLevel = level; + currentNode = target; + currentNeighbor = 0; + } + + @Override + public int size() { + return nodes[0].length; + } + + @Override + public int numLevels() { + return nodes.length; + } + + @Override + public int entryNode() { + return 0; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder(); + for (int level = nodes.length - 1; level >= 0; level--) { + buf.append("\nLEVEL ").append(level).append("\n"); + for (int node = 0; node < nodes[level].length; node++) { + if (nodes[level][node] != null) { + buf.append(" ") + .append(node) + .append(':') + .append(Arrays.toString(nodes[level][node])) + .append("\n"); + } + } + } + return buf.toString(); + } + + @Override + public NodesIterator getNodesOnLevel(int level) { + + int count = 0; + for (int i = 0; i < nodes[level].length; i++) { + if (nodes[level][i] != null) { + count++; + } + } + final int finalCount = count; + + return new NodesIterator(finalCount) { + int cur = -1; + int curCount = 0; + + @Override + public boolean hasNext() { + return curCount < finalCount; + } + + @Override + public int nextInt() { + while (curCount < finalCount) { + if (nodes[level][++cur] != null) { + curCount++; + return cur; + } + } + throw new IllegalStateException("exhausted"); + } + + @Override + public int consume(int[] dest) { + throw new UnsupportedOperationException(); + } + }; + } + } +} diff --git a/lucene/test-framework/src/java/module-info.java b/lucene/test-framework/src/java/module-info.java index 3e6311bc697a..f366d1f52b78 100644 --- a/lucene/test-framework/src/java/module-info.java +++ b/lucene/test-framework/src/java/module-info.java @@ -49,7 +49,6 @@ exports org.apache.lucene.tests.store; exports org.apache.lucene.tests.util.automaton; exports org.apache.lucene.tests.util.fst; - exports org.apache.lucene.tests.util.hnsw; exports org.apache.lucene.tests.util; provides org.apache.lucene.codecs.Codec with diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java deleted file mode 100644 index 955665544bcc..000000000000 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.tests.util.hnsw; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -import java.io.IOException; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Deque; -import java.util.List; -import org.apache.lucene.codecs.hnsw.HnswGraphProvider; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; -import org.apache.lucene.index.CodecReader; -import org.apache.lucene.index.FilterLeafReader; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.util.FixedBitSet; -import org.apache.lucene.util.hnsw.HnswGraph; - -/** Utilities for use in tests involving HNSW graphs */ -public class HnswTestUtil { - - /** - * Returns true iff level 0 of the graph is fully connected - that is every node is reachable from - * any entry point. - */ - public static boolean isFullyConnected(HnswGraph knnValues) throws IOException { - return componentSizes(knnValues).size() < 2; - } - - /** - * Returns the sizes of the distinct graph components on level 0. If the graph is fully-connected - * there will only be a single component. If the graph is empty, the returned list will be empty. - */ - public static List componentSizes(HnswGraph hnsw) throws IOException { - List sizes = new ArrayList<>(); - FixedBitSet connectedNodes = new FixedBitSet(hnsw.size()); - assert hnsw.size() == hnsw.getNodesOnLevel(0).size(); - int total = 0; - while (total < connectedNodes.length()) { - int componentSize = traverseConnectedNodes(hnsw, connectedNodes); - assert componentSize > 0; - sizes.add(componentSize); - total += componentSize; - } - return sizes; - } - - // count the nodes in a connected component of the graph and set the bits of its nodes in - // connectedNodes bitset - private static int traverseConnectedNodes(HnswGraph hnswGraph, FixedBitSet connectedNodes) - throws IOException { - // Start at entry point and search all nodes on this level - int entryPoint = nextClearBit(connectedNodes, 0); - if (entryPoint == NO_MORE_DOCS) { - return 0; - } - Deque stack = new ArrayDeque<>(); - stack.push(entryPoint); - int count = 0; - while (!stack.isEmpty()) { - int node = stack.pop(); - if (connectedNodes.get(node)) { - continue; - } - count++; - connectedNodes.set(node); - hnswGraph.seek(0, node); - int friendOrd; - while ((friendOrd = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) { - stack.push(friendOrd); - } - } - return count; - } - - private static int nextClearBit(FixedBitSet bits, int index) { - // Does not depend on the ghost bits being clear! - long[] barray = bits.getBits(); - assert index >= 0 && index < bits.length() : "index=" + index + ", numBits=" + bits.length(); - int i = index >> 6; - long word = ~(barray[i] >> index); // skip all the bits to the right of index - - if (word != 0) { - return index + Long.numberOfTrailingZeros(word); - } - - while (++i < barray.length) { - word = ~barray[i]; - if (word != 0) { - int next = (i << 6) + Long.numberOfTrailingZeros(word); - if (next >= bits.length()) { - return NO_MORE_DOCS; - } else { - return next; - } - } - } - return NO_MORE_DOCS; - } - - public static boolean graphIsConnected(IndexReader reader, String vectorField) - throws IOException { - for (LeafReaderContext ctx : reader.leaves()) { - CodecReader codecReader = (CodecReader) FilterLeafReader.unwrap(ctx.reader()); - HnswGraph graph = - ((HnswGraphProvider) - ((PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader()) - .getFieldReader(vectorField)) - .getGraph(vectorField); - if (isFullyConnected(graph) == false) { - return false; - } - } - return true; - } -}