diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 1bc6cfe3ccd9..a55713e33ef6 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -133,6 +133,8 @@ Optimizations * GITHUB#12128, GITHUB#12133: Speed up docvalues set query by making use of sortedness. (Robert Muir, Uwe Schindler) +* GITHUB#12050: Reuse HNSW graph for initialization during merge (Jack Mazanec) + Bug Fixes --------------------- (No changes) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 2b46cb498ab1..8e66d0755878 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -561,9 +561,9 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length); + return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length); } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java index 2d3ef582b472..e762e016bbff 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java @@ -163,9 +163,9 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel.get(level), graph.get(level).size()); + return new ArrayNodesIterator(nodesByLevel.get(level), graph.get(level).size()); } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index 8fb1b3a92d36..df51972e8ddf 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -457,9 +457,9 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length); + return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length); } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index 51a9aa23c271..ee6472ab2d6c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -533,9 +533,9 @@ public int entryNode() {
@Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length); + return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length); } } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index 2e125c824629..f6f378027603 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -345,7 +345,7 @@ public NodesIterator getNodesOnLevel(int level) { if (level == 0) { return graph.getNodesOnLevel(0); } else { - return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length); + return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length); } } }; @@ -687,10 +687,7 @@ public void addValue(int docID, Object value) throws IOException { assert docID > lastDocID; docsWithField.add(docID); vectors.add(copyValue(vectorValue)); - if (node > 0) { - // start at node 1! node 0 is added implicitly, in the constructor - hnswGraphBuilder.addGraphNode(node, vectorValue); - } + hnswGraphBuilder.addGraphNode(node, vectorValue); node++; lastDocID = docID; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java index 8d140b1fa1e0..185a472b5f48 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java @@ -573,9 +573,9 @@ public int entryNode() throws IOException { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length); + return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length); } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java index e9180d63d33d..bf0b79807f06 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -25,11 +25,16 @@ import java.nio.ByteOrder; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.*; import org.apache.lucene.index.Sorter; import org.apache.lucene.search.DocIdSetIterator; @@ -357,7 +362,7 @@ public NodesIterator getNodesOnLevel(int level) { if (level == 0) { return graph.getNodesOnLevel(0); } else { - return new 
NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length); + return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length); } } }; @@ -424,6 +429,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE int[][] vectorIndexNodeOffsets = null; if (docsWithField.cardinality() != 0) { // build graph + int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo); graph = switch (fieldInfo.getVectorEncoding()) { case BYTE -> { @@ -434,13 +440,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE vectorDataInput, byteSize); HnswGraphBuilder<byte[]> hnswGraphBuilder = - HnswGraphBuilder.create( - vectorValues, - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed); + createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); yield hnswGraphBuilder.build(vectorValues.copy()); } @@ -452,13 +452,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE vectorDataInput, byteSize); HnswGraphBuilder<float[]> hnswGraphBuilder = - HnswGraphBuilder.create( - vectorValues, - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed); + createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); yield hnswGraphBuilder.build(vectorValues.copy()); } @@ -489,6 +483,189 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE } } + private <T> HnswGraphBuilder<T> createHnswGraphBuilder( + MergeState mergeState, + FieldInfo fieldInfo, + RandomAccessVectorValues<T> floatVectorValues, + int initializerIndex) + throws IOException { + if (initializerIndex == -1) { + return HnswGraphBuilder.create( + floatVectorValues, + fieldInfo.getVectorEncoding(), + fieldInfo.getVectorSimilarityFunction(), + M, + beamWidth, + HnswGraphBuilder.randSeed); + } + + HnswGraph initializerGraph = + getHnswGraphFromReader(fieldInfo.name, mergeState.knnVectorsReaders[initializerIndex]); + Map<Integer, Integer> ordinalMapper = + getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex); + return HnswGraphBuilder.create( + floatVectorValues, + fieldInfo.getVectorEncoding(), + fieldInfo.getVectorSimilarityFunction(), + M, + beamWidth, + HnswGraphBuilder.randSeed, + initializerGraph, + ordinalMapper); + } + + private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo) + throws IOException { + // Find the KnnVectorsReader with the most docs that meets the following criteria: + // 1. Does not contain any deleted docs + // 2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader + // If no readers exist that meet these criteria, return -1.
If they do, return their index in + // merge state + int maxCandidateVectorCount = 0; + int initializerIndex = -1; + + for (int i = 0; i < mergeState.liveDocs.length; i++) { + KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i]; + if (mergeState.knnVectorsReaders[i] + instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name); + } + + if (!allMatch(mergeState.liveDocs[i]) + || !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader candidateReader)) { + continue; + } + + int candidateVectorCount = 0; + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> { + ByteVectorValues byteVectorValues = candidateReader.getByteVectorValues(fieldInfo.name); + if (byteVectorValues == null) { + continue; + } + candidateVectorCount = byteVectorValues.size(); + } + case FLOAT32 -> { + FloatVectorValues vectorValues = candidateReader.getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + candidateVectorCount = vectorValues.size(); + } + } + + if (candidateVectorCount > maxCandidateVectorCount) { + maxCandidateVectorCount = candidateVectorCount; + initializerIndex = i; + } + } + return initializerIndex; + } + + private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader) + throws IOException { + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader + && perFieldReader.getFieldReader(fieldName) + instanceof Lucene95HnswVectorsReader fieldReader) { + return fieldReader.getGraph(fieldName); + } + + if (knnVectorsReader instanceof Lucene95HnswVectorsReader) { + return ((Lucene95HnswVectorsReader) knnVectorsReader).getGraph(fieldName); + } + + // We should not reach here because knnVectorsReader's type is checked in + // selectGraphForInitialization + throw new IllegalArgumentException( + "Invalid KnnVectorsReader type for field: " + + fieldName + + ". 
Must be Lucene95HnswVectorsReader or newer"); + } + + private Map<Integer, Integer> getOldToNewOrdinalMap( + MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException { + + DocIdSetIterator initializerIterator = null; + + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> initializerIterator = + mergeState.knnVectorsReaders[initializerIndex].getByteVectorValues(fieldInfo.name); + case FLOAT32 -> initializerIterator = + mergeState.knnVectorsReaders[initializerIndex].getFloatVectorValues(fieldInfo.name); + } + + MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex]; + + Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>(); + int oldOrd = 0; + int maxNewDocID = -1; + for (int oldId = initializerIterator.nextDoc(); + oldId != NO_MORE_DOCS; + oldId = initializerIterator.nextDoc()) { + if (isCurrentVectorNull(initializerIterator)) { + continue; + } + int newId = initializerDocMap.get(oldId); + maxNewDocID = Math.max(newId, maxNewDocID); + newIdToOldOrdinal.put(newId, oldOrd); + oldOrd++; + } + + if (maxNewDocID == -1) { + return Collections.emptyMap(); + } + + Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>(); + + DocIdSetIterator vectorIterator = null; + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> vectorIterator = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + case FLOAT32 -> vectorIterator = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + } + + int newOrd = 0; + for (int newDocId = vectorIterator.nextDoc(); + newDocId <= maxNewDocID; + newDocId = vectorIterator.nextDoc()) { + if (isCurrentVectorNull(vectorIterator)) { + continue; + } + + if (newIdToOldOrdinal.containsKey(newDocId)) { + oldToNewOrdinalMap.put(newIdToOldOrdinal.get(newDocId), newOrd); + } + newOrd++; + } + + return oldToNewOrdinalMap; + } + + private boolean isCurrentVectorNull(DocIdSetIterator docIdSetIterator) throws IOException { + if (docIdSetIterator instanceof FloatVectorValues) { + return ((FloatVectorValues) docIdSetIterator).vectorValue() == null; + } + + if (docIdSetIterator instanceof ByteVectorValues) { + return ((ByteVectorValues) docIdSetIterator).vectorValue() == null; + } + + return true; + } + + private boolean allMatch(Bits bits) { + if (bits == null) { + return true; + } + + for (int i = 0; i < bits.length(); i++) { + if (!bits.get(i)) { + return false; + } + } + return true; + } + /** * @param graph Write the graph in a compressed format * @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets. @@ -735,10 +912,7 @@ public void addValue(int docID, T vectorValue) throws IOException { assert docID > lastDocID; docsWithField.add(docID); vectors.add(copyValue(vectorValue)); - if (node > 0) { - // start at node 1!
node 0 is added implicitly, in the constructor - hnswGraphBuilder.addGraphNode(node, vectorValue); - } + hnswGraphBuilder.addGraphNode(node, vectorValue); node++; lastDocID = docID; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java index fc7b0be82fbb..9086ab55d2eb 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -20,6 +20,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import java.util.Collection; +import java.util.Iterator; import java.util.NoSuchElementException; import java.util.PrimitiveIterator; import org.apache.lucene.index.FloatVectorValues; @@ -115,7 +117,7 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { - return NodesIterator.EMPTY; + return ArrayNodesIterator.EMPTY; } }; @@ -123,33 +125,50 @@ public NodesIterator getNodesOnLevel(int level) { * Iterator over the graph nodes on a certain level, Iterator also provides the size – the total * number of nodes to be iterated over. */ - public static final class NodesIterator implements PrimitiveIterator.OfInt { - static NodesIterator EMPTY = new NodesIterator(0); - - private final int[] nodes; - private final int size; - int cur = 0; - - /** Constructor for iterator based on the nodes array up to the size */ - public NodesIterator(int[] nodes, int size) { - assert nodes != null; - assert size <= nodes.length; - this.nodes = nodes; - this.size = size; - } + public abstract static class NodesIterator implements PrimitiveIterator.OfInt { + protected final int size; /** Constructor for iterator based on the size */ public NodesIterator(int size) { - this.nodes = null; this.size = size; } + /** The number of elements in this iterator * */ + public int size() { + return size; + } + /** * Consume integers from the iterator and place them into the `dest` array. * * @param dest where to put the integers * @return The number of integers written to `dest` */ + public abstract int consume(int[] dest); + } + + /** NodesIterator that accepts nodes as an integer array. */ + public static class ArrayNodesIterator extends NodesIterator { + static NodesIterator EMPTY = new ArrayNodesIterator(0); + + private final int[] nodes; + private int cur = 0; + + /** Constructor for iterator based on integer array representing nodes */ + public ArrayNodesIterator(int[] nodes, int size) { + super(size); + assert nodes != null; + assert size <= nodes.length; + this.nodes = nodes; + } + + /** Constructor for iterator based on the size */ + public ArrayNodesIterator(int size) { + super(size); + this.nodes = null; + } + + @Override public int consume(int[] dest) { if (hasNext() == false) { throw new NoSuchElementException(); @@ -182,10 +201,43 @@ public int nextInt() { public boolean hasNext() { return cur < size; } + } - /** The number of elements in this iterator * */ - public int size() { - return size; + /** Nodes iterator based on set representation of nodes. 
*/ + public static class CollectionNodesIterator extends NodesIterator { + Iterator<Integer> nodes; + + /** Constructor for iterator based on collection representing nodes */ + public CollectionNodesIterator(Collection<Integer> nodes) { + super(nodes.size()); + this.nodes = nodes.iterator(); + } + + @Override + public int consume(int[] dest) { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + + int destIndex = 0; + while (hasNext() && destIndex < dest.length) { + dest[destIndex++] = nextInt(); + } + + return destIndex; + } + + @Override + public int nextInt() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + return nodes.next(); + } + + @Override + public boolean hasNext() { + return nodes.hasNext(); } } }
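The hunks above split `NodesIterator` into two concrete shapes: `ArrayNodesIterator` for the codec readers' sorted `int[]` level arrays, and `CollectionNodesIterator` for node ids served straight from a collection (e.g. a `TreeMap` key set). A minimal illustrative sketch of the shared contract — the demo harness is mine, only the class names come from this patch:

```java
import java.util.List;
import org.apache.lucene.util.hnsw.HnswGraph.ArrayNodesIterator;
import org.apache.lucene.util.hnsw.HnswGraph.CollectionNodesIterator;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;

public class NodesIteratorSketch {
  public static void main(String[] args) {
    // Array-backed: node ids live in a sorted int[]; size may be less than the array length.
    NodesIterator fromArray = new ArrayNodesIterator(new int[] {0, 3, 7, 9}, 4);

    // Collection-backed: node ids come from a collection's iterator, e.g. the
    // TreeMap key set of OnHeapHnswGraph, which already yields ascending order.
    NodesIterator fromCollection = new CollectionNodesIterator(List.of(1, 4, 6));

    int[] dest = new int[8];
    int copied = fromArray.consume(dest); // bulk-drains node ids, returns the count written
    System.out.println(copied + " drained; other iterator holds " + fromCollection.size());
  }
}
```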
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index e29329932610..9f1e6c505254 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -18,10 +18,14 @@ package org.apache.lucene.util.hnsw; import static java.lang.Math.log; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import java.util.HashSet; import java.util.Locale; +import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; import org.apache.lucene.index.VectorEncoding; @@ -63,6 +67,7 @@ public final class HnswGraphBuilder<T> { // we need two sources of vectors in order to perform diversity check comparisons without // colliding private final RandomAccessVectorValues<T> vectorsCopy; + private final Set<Integer> initializedNodes; public static <T> HnswGraphBuilder<T> create( RandomAccessVectorValues<T> vectors, @@ -75,6 +80,22 @@ public static <T> HnswGraphBuilder<T> create( return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed); } + public static <T> HnswGraphBuilder<T> create( + RandomAccessVectorValues<T> vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth, + long seed, + HnswGraph initializerGraph, + Map<Integer, Integer> oldToNewOrdinalMap) + throws IOException { + HnswGraphBuilder<T> hnswGraphBuilder = + new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed); + hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap); + return hnswGraphBuilder; + } + /** * Reads all the vectors from vector values, builds a graph connecting them by their dense * ordinals, using the given hyperparameter settings, and returns the resulting graph. @@ -110,8 +131,7 @@ private HnswGraphBuilder( // normalization factor for level generation; currently not configurable this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); this.random = new SplittableRandom(seed); - int levelOfFirstNode = getRandomGraphLevel(ml, random); - this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode); + this.hnsw = new OnHeapHnswGraph(M); this.graphSearcher = new HnswGraphSearcher<>( vectorEncoding, @@ -120,6 +140,7 @@ private HnswGraphBuilder( new FixedBitSet(this.vectors.size())); // in scratch we store candidates in reverse order: worse candidates are first scratch = new NeighborArray(Math.max(beamWidth, M + 1), false); + this.initializedNodes = new HashSet<>(); } /** @@ -142,10 +163,64 @@ public OnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd) throws IO return hnsw; } + /** + * Initializes the graph of this builder. Transfers the nodes and their neighbors from the + * initializer graph into the graph being produced by this builder, mapping ordinals from the + * initializer graph to their new ordinals in this builder's graph. The builder's graph must be + * empty before calling this method. + * + * @param initializerGraph graph used for initialization + * @param oldToNewOrdinalMap map for converting from ordinals in the initializerGraph to this + * builder's graph + */ + private void initializeFromGraph( + HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException { + assert hnsw.size() == 0; + float[] vectorValue = null; + byte[] binaryValue = null; + for (int level = 0; level < initializerGraph.numLevels(); level++) { + HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level); + + while (it.hasNext()) { + int oldOrd = it.nextInt(); + int newOrd = oldToNewOrdinalMap.get(oldOrd); + + hnsw.addNode(level, newOrd); + + if (level == 0) { + initializedNodes.add(newOrd); + } + + switch (this.vectorEncoding) { + case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(newOrd); + case BYTE -> binaryValue = (byte[]) vectors.vectorValue(newOrd); + } + + NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd); + initializerGraph.seek(level, oldOrd); + for (int oldNeighbor = initializerGraph.nextNeighbor(); + oldNeighbor != NO_MORE_DOCS; + oldNeighbor = initializerGraph.nextNeighbor()) { + int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor); + float score = + switch (this.vectorEncoding) { + case FLOAT32 -> this.similarityFunction.compare( + vectorValue, (float[]) vectorsCopy.vectorValue(newNeighbor)); + case BYTE -> this.similarityFunction.compare( + binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor)); + }; + newNeighbors.insertSorted(newNeighbor, score); + } + } + } + } + private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException { long start = System.nanoTime(), t = start; - // start at node 1! node 0 is added implicitly, in the constructor - for (int node = 1; node < vectorsToAdd.size(); node++) { + for (int node = 0; node < vectorsToAdd.size(); node++) { + if (initializedNodes.contains(node)) { + continue; + } addGraphNode(node, vectorsToAdd); if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) { t = printGraphBuildStatus(node, start, t); @@ -167,6 +242,14 @@ public void addGraphNode(int node, T value) throws IOException { NeighborQueue candidates; final int nodeLevel = getRandomGraphLevel(ml, random); int curMaxLevel = hnsw.numLevels() - 1; + + // If the entry node is -1 (empty graph), this should finish without adding neighbors + if (hnsw.entryNode() == -1) { + for (int level = nodeLevel; level >= 0; level--) { + hnsw.addNode(level, node); + } + return; + } int[] eps = new int[] {hnsw.entryNode()}; // if a node introduces new levels to the graph, add this new node on new levels
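The new `create` overload seeds a builder from an existing graph, so `build` only inserts the vectors that were not already transferred. A minimal sketch of that path, under stated assumptions: `toyVectors` is a hypothetical in-memory `RandomAccessVectorValues<float[]>` (the real merge code uses off-heap vector values), the identity ordinal map assumes the initializer segment's vectors lead the merged order, and the M/beamWidth/seed values are arbitrary:

```java
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

public class SeededBuildSketch {
  // Hypothetical in-memory vector source; stable arrays make copy() trivial.
  static RandomAccessVectorValues<float[]> toyVectors(float[][] data) {
    return new RandomAccessVectorValues<>() {
      @Override public int size() { return data.length; }
      @Override public int dimension() { return data[0].length; }
      @Override public float[] vectorValue(int ord) { return data[ord]; }
      @Override public RandomAccessVectorValues<float[]> copy() { return this; }
    };
  }

  public static void main(String[] args) throws IOException {
    float[][] old = {{1, 0}, {0, 1}};            // vectors of the initializer segment
    float[][] merged = {{1, 0}, {0, 1}, {1, 1}}; // merged order: old vectors come first here

    RandomAccessVectorValues<float[]> oldVals = toyVectors(old);
    OnHeapHnswGraph initializer =
        HnswGraphBuilder.create(
                oldVals, VectorEncoding.FLOAT32, VectorSimilarityFunction.EUCLIDEAN, 16, 100, 42L)
            .build(oldVals.copy());

    // Old ordinal -> new ordinal; identity because the old vectors lead the merged order.
    Map<Integer, Integer> oldToNew = new HashMap<>();
    oldToNew.put(0, 0);
    oldToNew.put(1, 1);

    RandomAccessVectorValues<float[]> mergedVals = toyVectors(merged);
    HnswGraphBuilder<float[]> builder =
        HnswGraphBuilder.create(
            mergedVals, VectorEncoding.FLOAT32, VectorSimilarityFunction.EUCLIDEAN,
            16, 100, 42L, initializer, oldToNew);
    // build() skips the initialized nodes and only inserts the remaining vectors.
    OnHeapHnswGraph graph = builder.build(mergedVals.copy());
    System.out.println("merged graph size: " + graph.size());
  }
}
```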
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 13a338e0d183..4857d5b9d577 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -101,7 +101,12 @@ public static NeighborQueue search( new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); NeighborQueue results; - int[] eps = new int[] {graph.entryNode()}; + + int initialEp = graph.entryNode(); + if (initialEp == -1) { + return new NeighborQueue(1, true); + } + int[] eps = new int[] {initialEp}; int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
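The searcher guard above exists because `OnHeapHnswGraph` now starts genuinely empty (entry node -1) instead of pre-seeding node 0 in the constructor. A small sketch of that invariant; it lives in package `org.apache.lucene.util.hnsw` because the `OnHeapHnswGraph` constructor is package-private, and the `main` harness is illustrative:

```java
package org.apache.lucene.util.hnsw; // OnHeapHnswGraph's constructor is package-private

public class EmptyGraphSketch {
  public static void main(String[] args) {
    OnHeapHnswGraph graph = new OnHeapHnswGraph(16);

    // Freshly constructed graph: no nodes, no entry point yet.
    System.out.println(graph.entryNode()); // -1: searchers return empty results instead of seeking
    System.out.println(graph.size());      // 0

    // The first node added, at any level, becomes the entry point.
    graph.addNode(0, 0);
    System.out.println(graph.entryNode()); // 0
  }
}
```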
Each level is represented with a TreeMap that maps a levels level 0 + // ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain + // it in this list. However, to avoid changing list indexing, we always will make the first + // element + // null. + private final List> graphUpperLevels; private final int nsize; private final int nsize0; @@ -53,24 +53,17 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { private int upto; private NeighborArray cur; - OnHeapHnswGraph(int M, int levelOfFirstNode) { - this.numLevels = levelOfFirstNode + 1; - this.graph = new ArrayList<>(numLevels); - this.entryNode = 0; + OnHeapHnswGraph(int M) { + this.numLevels = 1; // Implicitly start the graph with a single level + this.graphLevel0 = new ArrayList<>(); + this.entryNode = -1; // Entry node should be negative until a node is added // Neighbours' size on upper levels (nsize) and level 0 (nsize0) // We allocate extra space for neighbours, but then prune them to keep allowed maximum this.nsize = M + 1; this.nsize0 = (M * 2 + 1); - for (int l = 0; l < numLevels; l++) { - graph.add(new ArrayList<>()); - graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, true)); - } - this.nodesByLevel = new ArrayList<>(numLevels); - nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes - for (int l = 1; l < numLevels; l++) { - nodesByLevel.add(new int[] {0}); - } + this.graphUpperLevels = new ArrayList<>(numLevels); + graphUpperLevels.add(null); // we don't need this for 0th level, as it contains all nodes } /** @@ -81,49 +74,52 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { */ public NeighborArray getNeighbors(int level, int node) { if (level == 0) { - return graph.get(level).get(node); + return graphLevel0.get(node); } - int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node); - assert nodeIndex >= 0; - return graph.get(level).get(nodeIndex); + TreeMap levelMap = graphUpperLevels.get(level); + assert levelMap.containsKey(node); + return levelMap.get(node); } @Override public int size() { - return graph.get(0).size(); // all nodes are located on the 0th level + return graphLevel0.size(); // all nodes are located on the 0th level } /** - * Add node on the given level + * Add node on the given level. Nodes can be inserted out of order, but it requires that the nodes + * preceded by the node inserted out of order are eventually added. * * @param level level to add a node on * @param node the node to add, represented as an ordinal on the level 0. */ public void addNode(int level, int node) { + if (entryNode == -1) { + entryNode = node; + } + if (level > 0) { // if the new node introduces a new level, add more levels to the graph, // and make this node the graph's new entry point if (level >= numLevels) { for (int i = numLevels; i <= level; i++) { - graph.add(new ArrayList<>()); - nodesByLevel.add(new int[] {node}); + graphUpperLevels.add(new TreeMap<>()); } numLevels = level + 1; entryNode = node; - } else { - // Add this node id to this level's nodes - int[] nodes = nodesByLevel.get(level); - int idx = graph.get(level).size(); - if (idx < nodes.length) { - nodes[idx] = node; - } else { - nodes = ArrayUtil.grow(nodes); - nodes[idx] = node; - nodesByLevel.set(level, nodes); - } + } + + graphUpperLevels.get(level).put(node, new NeighborArray(nsize, true)); + } else { + // Add nodes all the way up to and including "node" in the new graph on level 0. 
This will + // cause the size of the + // graph to differ from the number of nodes added to the graph. The size of the graph and the + // number of nodes + // added will only be in sync once all nodes from 0...last_node are added into the graph. + while (node >= graphLevel0.size()) { + graphLevel0.add(new NeighborArray(nsize0, true)); } } - graph.get(level).add(new NeighborArray(level == 0 ? nsize0 : nsize, true)); } @Override @@ -164,9 +160,9 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel.get(level), graph.get(level).size()); + return new CollectionNodesIterator(graphUpperLevels.get(level).keySet()); } } @@ -184,19 +180,26 @@ public long ramBytesUsed() { + Integer.BYTES * 2; long total = 0; for (int l = 0; l < numLevels; l++) { - int numNodesOnLevel = graph.get(l).size(); if (l == 0) { total += - numNodesOnLevel * neighborArrayBytes0 + graphLevel0.size() * neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph; } else { + long numNodesOnLevel = graphUpperLevels.get(l).size(); + + // For levels > 0, we represent the graph structure with a tree map. + // A single node in the tree contains 3 references (left root, right root, value) as well + // as an Integer for the key and 1 extra byte for the color of the node (this is actually 1 + // bit, but + // because we do not have that granularity, we set to 1 byte). In addition, we include 1 + // more reference for + // the tree map itself. total += - nodesByLevel.get(l).length * Integer.BYTES - + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER - + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for nodesByLevel - total += - numNodesOnLevel * neighborArrayBytes - + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph; + numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1) + + RamUsageEstimator.NUM_BYTES_OBJECT_REF; + + // Add the size neighbor of each node + total += numNodesOnLevel * neighborArrayBytes; } } return total; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index bba92fab2224..08f089430ba5 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -48,14 +48,12 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; -import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.junit.After; import org.junit.Before; @@ -179,21 +177,6 @@ public void testMerge() throws Exception { } } - /** - * Verify that we get the *same* graph by indexing one segment as we do by indexing two segments - * and merging. 
- */ - public void testMergeProducesSameGraph() throws Exception { - long seed = random().nextLong(); - int numDoc = atLeast(100); - int dimension = atLeast(10); - float[][] values = randomVectors(numDoc, dimension); - int mergePoint = random().nextInt(numDoc); - int[][][] mergedGraph = getIndexedGraph(values, mergePoint, seed); - int[][][] singleSegmentGraph = getIndexedGraph(values, -1, seed); - assertGraphEquals(singleSegmentGraph, mergedGraph); - } - /** Test writing and reading of multiple vector fields * */ public void testMultipleVectorFields() throws Exception { int numVectorFields = randomIntBetween(2, 5); @@ -227,52 +210,6 @@ public void testMultipleVectorFields() throws Exception { } } - private void assertGraphEquals(int[][][] expected, int[][][] actual) { - assertEquals("graph sizes differ", expected.length, actual.length); - for (int level = 0; level < expected.length; level++) { - for (int node = 0; node < expected[level].length; node++) { - assertArrayEquals("difference at ord=" + node, expected[level][node], actual[level][node]); - } - } - } - - /** - * Return a naive representation of an HNSW graph as a 3 dimensional array: 1st dim represents a - * graph layer. Each layer contains an array of arrays – a list of nodes and for each node a list - * of the node's neighbours. 2nd dim represents a node on a layer, and contains the node's - * neighbourhood, or {@code null} if a node is not present on this layer. 3rd dim represents - * neighbours of a node. - */ - private int[][][] getIndexedGraph(float[][] values, int mergePoint, long seed) - throws IOException { - HnswGraphBuilder.randSeed = seed; - int[][][] graph; - try (Directory dir = newDirectory()) { - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setMergePolicy(new LogDocMergePolicy()); // for predictable segment ordering when merging - iwc.setCodec(codec); // don't use SimpleTextCodec - try (IndexWriter iw = new IndexWriter(dir, iwc)) { - for (int i = 0; i < values.length; i++) { - add(iw, i, values[i]); - if (i == mergePoint) { - // flush proactively to create a segment - iw.flush(); - } - } - iw.forceMerge(1); - } - try (IndexReader reader = DirectoryReader.open(dir)) { - PerFieldKnnVectorsFormat.FieldsReader perFieldReader = - (PerFieldKnnVectorsFormat.FieldsReader) - ((CodecReader) getOnlyLeafReader(reader)).getVectorReader(); - Lucene95HnswVectorsReader vectorReader = - (Lucene95HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD); - graph = copyGraph(vectorReader.getGraph(KNN_GRAPH_FIELD)); - } - } - return graph; - } - private float[][] randomVectors(int numDoc, int dimension) { float[][] values = new float[numDoc][]; for (int i = 0; i < numDoc; i++) { @@ -297,27 +234,6 @@ private float[] randomVector(int dimension) { return value; } - int[][][] copyGraph(HnswGraph graphValues) throws IOException { - int[][][] graph = new int[graphValues.numLevels()][][]; - int size = graphValues.size(); - int[] scratch = new int[M * 2]; - - for (int level = 0; level < graphValues.numLevels(); level++) { - NodesIterator nodesItr = graphValues.getNodesOnLevel(level); - graph[level] = new int[size][]; - while (nodesItr.hasNext()) { - int node = nodesItr.nextInt(); - graphValues.seek(level, node); - int n, count = 0; - while ((n = graphValues.nextNeighbor()) != NO_MORE_DOCS) { - scratch[count++] = n; - } - graph[level][node] = ArrayUtil.copyOfSubArray(scratch, 0, count); - } - } - return graph; - } - /** Verify that searching does something reasonable */ public void testSearch() throws Exception { // We can't 
use dot product here since the vectors are laid out on a grid, not a sphere. diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 10ce6b42a78d..80c9c7a93cf4 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -25,10 +25,14 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.stream.Collectors; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene95.Lucene95Codec; import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat; @@ -84,6 +88,12 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase { abstract AbstractMockVectorValues<T> vectorValues(LeafReader reader, String fieldName) throws IOException; + abstract AbstractMockVectorValues<T> vectorValues( + int size, + int dimension, + AbstractMockVectorValues<T> pregeneratedVectorValues, + int pregeneratedOffset); + abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction); abstract RandomAccessVectorValues<T> circularVectorValues(int nDoc); @@ -427,6 +437,238 @@ public void testSearchWithSelectiveAcceptOrds() throws IOException { } } + public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException { + int maxNumLevels = randomIntBetween(2, 10); + int nodeCount = randomIntBetween(1, 100); + + List<List<Integer>> nodesPerLevel = new ArrayList<>(); + for (int i = 0; i < maxNumLevels; i++) { + nodesPerLevel.add(new ArrayList<>()); + } + + int numLevels = 0; + for (int currNode = 0; currNode < nodeCount; currNode++) { + int nodeMaxLevel = random().nextInt(1, maxNumLevels + 1); + numLevels = Math.max(numLevels, nodeMaxLevel); + for (int currLevel = 0; currLevel < nodeMaxLevel; currLevel++) { + nodesPerLevel.get(currLevel).add(currNode); + } + } + + OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10); + for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) { + List<Integer> currLevelNodes = nodesPerLevel.get(currLevel); + int currLevelNodesSize = currLevelNodes.size(); + for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) { + topDownOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd)); + } + } + + OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10); + for (int currLevel = 0; currLevel < numLevels; currLevel++) { + List<Integer> currLevelNodes = nodesPerLevel.get(currLevel); + int currLevelNodesSize = currLevelNodes.size(); + for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) { + bottomUpOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd)); + } + } + + OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10); + for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) { + List<Integer> currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel)); + Collections.shuffle(currLevelNodes, random()); + for (Integer currNode : currLevelNodes) { + topDownOrderRandomHnsw.addNode(currLevel, currNode); + } + } + + OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10); + for (int currLevel = 0; currLevel < numLevels; currLevel++) { + for (Integer currNode : nodesPerLevel.get(currLevel)) {
bottomUpExpectedHnsw.addNode(currLevel, currNode); + } + } + + assertEquals(nodeCount, bottomUpExpectedHnsw.getNodesOnLevel(0).size()); + for (Integer node : nodesPerLevel.get(0)) { + assertEquals(0, bottomUpExpectedHnsw.getNeighbors(0, node).size()); + } + + for (int currLevel = 1; currLevel < numLevels; currLevel++) { + NodesIterator nodesIterator = bottomUpExpectedHnsw.getNodesOnLevel(currLevel); + List<Integer> expectedNodesOnLevel = nodesPerLevel.get(currLevel); + assertEquals(expectedNodesOnLevel.size(), nodesIterator.size()); + for (Integer expectedNode : expectedNodesOnLevel) { + int currentNode = nodesIterator.nextInt(); + assertEquals(expectedNode.intValue(), currentNode); + assertEquals(0, bottomUpExpectedHnsw.getNeighbors(currLevel, currentNode).size()); + } + } + + assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw); + assertGraphEqual(bottomUpExpectedHnsw, bottomUpOrderReversedHnsw); + assertGraphEqual(bottomUpExpectedHnsw, topDownOrderRandomHnsw); + } + + public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException { + int totalSize = atLeast(100); + int initializerSize = random().nextInt(5, totalSize); + int docIdOffset = 0; + int dim = atLeast(10); + long seed = random().nextLong(); + + AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim); + HnswGraphBuilder<T> initializerBuilder = + HnswGraphBuilder.create( + initializerVectors, getVectorEncoding(), similarityFunction, 10, 30, seed); + + OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy()); + AbstractMockVectorValues<T> finalVectorValues = + vectorValues(totalSize, dim, initializerVectors, docIdOffset); + + Map<Integer, Integer> initializerOrdMap = + createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset); + + HnswGraphBuilder<T> finalBuilder = + HnswGraphBuilder.create( + finalVectorValues, + getVectorEncoding(), + similarityFunction, + 10, + 30, + seed, + initializerGraph, + initializerOrdMap); + + // When offset is 0, the graphs should be identical before vectors are added + assertGraphEqual(initializerGraph, finalBuilder.getGraph()); + + OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy()); + assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap); + } + + public void testHnswGraphBuilderInitializationFromGraph_withNonZeroOffset() throws IOException { + int totalSize = atLeast(100); + int initializerSize = random().nextInt(5, totalSize); + int docIdOffset = random().nextInt(1, totalSize - initializerSize + 1); + int dim = atLeast(10); + long seed = random().nextLong(); + + AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim); + HnswGraphBuilder<T> initializerBuilder = + HnswGraphBuilder.create( + initializerVectors.copy(), getVectorEncoding(), similarityFunction, 10, 30, seed); + OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy()); + AbstractMockVectorValues<T> finalVectorValues = + vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset); + Map<Integer, Integer> initializerOrdMap = + createOffsetOrdinalMap(initializerSize, finalVectorValues.copy(), docIdOffset); + + HnswGraphBuilder<T> finalBuilder = + HnswGraphBuilder.create( + finalVectorValues, + getVectorEncoding(), + similarityFunction, + 10, + 30, + seed, + initializerGraph, + initializerOrdMap); + + assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap); + + // Confirm that the graph is appropriately constructed by checking that the nodes in the old
// graph are present in the levels of the new graph + OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy()); + assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap); + } + + private void assertGraphContainsGraph( + HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException { + for (int i = 0; i < h.numLevels(); i++) { + int[] finalGraphNodesOnLevel = nodesIteratorToArray(g.getNodesOnLevel(i)); + int[] initializerGraphNodesOnLevel = + mapArrayAndSort(nodesIteratorToArray(h.getNodesOnLevel(i)), oldToNewOrdMap); + int overlap = computeOverlap(finalGraphNodesOnLevel, initializerGraphNodesOnLevel); + assertEquals(initializerGraphNodesOnLevel.length, overlap); + } + } + + private void assertGraphInitializedFromGraph( + HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException { + assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels()); + // Confirm that the size of the new graph includes all nodes up to and including the max new + // ordinal in the old-to-new ordinal mapping + assertEquals( + "the number of nodes in the graphs are different!", + g.size(), + Collections.max(oldToNewOrdMap.values()) + 1); + + // assert that the nodes from the previous graph were successfully transferred to levels > 0 + // in the new graph + for (int level = 1; level < g.numLevels(); level++) { + NodesIterator nodesOnLevel = g.getNodesOnLevel(level); + NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level); + while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) { + int node = nodesOnLevel.nextInt(); + int node2 = oldToNewOrdMap.get(nodesOnLevel2.nextInt()); + assertEquals("nodes in the graphs are different", node, node2); + } + } + + // assert that the neighbors from the old graph are successfully transferred to the new graph + for (int level = 0; level < g.numLevels(); level++) { + NodesIterator nodesOnLevel = h.getNodesOnLevel(level); + while (nodesOnLevel.hasNext()) { + int node = nodesOnLevel.nextInt(); + g.seek(level, oldToNewOrdMap.get(node)); + h.seek(level, node); + assertEquals( + "arcs differ for node " + node, + getNeighborNodes(g), + getNeighborNodes(h).stream().map(oldToNewOrdMap::get).collect(Collectors.toSet())); + } + } + } + + private Map<Integer, Integer> createOffsetOrdinalMap( + int docIdSize, AbstractMockVectorValues<T> totalVectorValues, int docIdOffset) { + // Compute the offset for the ordinal map to be the number of non-null vectors in the total + // vector values before the docIdOffset + int ordinalOffset = 0; + while (totalVectorValues.nextDoc() < docIdOffset) { + ordinalOffset++; + } + + Map<Integer, Integer> offsetOrdinalMap = new HashMap<>(); + for (int curr = 0; + totalVectorValues.docID() < docIdOffset + docIdSize; + totalVectorValues.nextDoc()) { + offsetOrdinalMap.put(curr, ordinalOffset + curr++); + } + + return offsetOrdinalMap; + } + + private int[] nodesIteratorToArray(NodesIterator nodesIterator) { + int[] arr = new int[nodesIterator.size()]; + int i = 0; + while (nodesIterator.hasNext()) { + arr[i++] = nodesIterator.nextInt(); + } + return arr; + } + + private int[] mapArrayAndSort(int[] arr, Map<Integer, Integer> map) { + int[] mappedA = new int[arr.length]; + for (int i = 0; i < arr.length; i++) { + mappedA[i] = map.get(arr[i]); + } + Arrays.sort(mappedA); + return mappedA; + } + @SuppressWarnings("unchecked") public void testVisitedLimit() throws IOException { int nDoc = 500;
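For the dense case these tests construct, `createOffsetOrdinalMap` produces a contiguous old-to-new ordinal mapping shifted by however many merged vectors precede the initializer's block. Reduced to plain Java (the numbers are made up for illustration):

```java
import java.util.HashMap;
import java.util.Map;

public class OffsetOrdinalMapSketch {
  public static void main(String[] args) {
    int initializerSize = 4; // initializer ordinals 0..3
    int ordinalOffset = 7;   // vectors that precede them in the merged ordering

    Map<Integer, Integer> oldToNew = new HashMap<>();
    for (int oldOrd = 0; oldOrd < initializerSize; oldOrd++) {
      oldToNew.put(oldOrd, ordinalOffset + oldOrd); // one contiguous, shifted block
    }
    System.out.println(oldToNew); // {0=7, 1=8, 2=9, 3=10}
  }
}
```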
@@ -531,8 +773,8 @@ public void testDiversity() throws IOException { HnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt()); - // node 0 is added by the builder constructor - // builder.addGraphNode(vectors.vectorValue(0)); RandomAccessVectorValues<T> vectorsCopy = vectors.copy(); + builder.addGraphNode(0, vectorsCopy); builder.addGraphNode(1, vectorsCopy); builder.addGraphNode(2, vectorsCopy); // now every node has tried to attach every other node as a neighbor, but @@ -586,9 +828,8 @@ public void testDiversityFallback() throws IOException { HnswGraphBuilder<T> builder = HnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt()); - // node 0 is added by the builder constructor - // builder.addGraphNode(vectors.vectorValue(0)); RandomAccessVectorValues<T> vectorsCopy = vectors.copy(); + builder.addGraphNode(0, vectorsCopy); builder.addGraphNode(1, vectorsCopy); builder.addGraphNode(2, vectorsCopy); assertLevel0Neighbors(builder.hnsw, 0, 1, 2); @@ -619,9 +860,8 @@ public void testDiversity3d() throws IOException { HnswGraphBuilder<T> builder = HnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt()); - // node 0 is added by the builder constructor - // builder.addGraphNode(vectors.vectorValue(0)); RandomAccessVectorValues<T> vectorsCopy = vectors.copy(); + builder.addGraphNode(0, vectorsCopy); builder.addGraphNode(1, vectorsCopy); builder.addGraphNode(2, vectorsCopy); assertLevel0Neighbors(builder.hnsw, 0, 1, 2); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java index 258864ade7dd..3a2d92ff92a6 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java @@ -85,6 +85,34 @@ AbstractMockVectorValues<byte[]> vectorValues(float[][] values) { return MockByteVectorValues.fromValues(bValues); } + @Override + AbstractMockVectorValues<byte[]> vectorValues( + int size, + int dimension, + AbstractMockVectorValues<byte[]> pregeneratedVectorValues, + int pregeneratedOffset) { + byte[][] vectors = new byte[size][]; + byte[][] randomVectors = + createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, random()); + + for (int i = 0; i < pregeneratedOffset; i++) { + vectors[i] = randomVectors[i]; + } + + int currentDoc; + while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) { + vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc]; + } + + for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length; + i < vectors.length; + i++) { + vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length]; + } + + return MockByteVectorValues.fromValues(vectors); + } + @Override AbstractMockVectorValues<byte[]> vectorValues(LeafReader reader, String fieldName) throws IOException { diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 16f2e7330e27..5dda5bf0a838 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -79,6 +79,35 @@ AbstractMockVectorValues<float[]> vectorValues(LeafReader reader, String fieldNa return MockVectorValues.fromValues(vectors); } + @Override + AbstractMockVectorValues<float[]> vectorValues( + int size, + int dimension, + AbstractMockVectorValues<float[]> pregeneratedVectorValues, + int pregeneratedOffset) {
+ float[][] vectors = new float[size][]; + float[][] randomVectors = + createRandomFloatVectors( + size - pregeneratedVectorValues.values.length, dimension, random()); + + for (int i = 0; i < pregeneratedOffset; i++) { + vectors[i] = randomVectors[i]; + } + + int currentDoc; + while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) { + vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc]; + } + + for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length; + i < vectors.length; + i++) { + vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length]; + } + + return MockVectorValues.fromValues(vectors); + } + @Override Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) { return new KnnFloatVectorField(name, vector, similarityFunction);
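Taken together, the merge path is: pick an initializer segment, copy its graph under an ordinal remapping, then insert only the remaining vectors. The selection policy from `selectGraphForInitialization`, reduced to plain data — the `Segment` record and the counts below are illustrative stand-ins, not Lucene API:

```java
public class InitializerSelectionSketch {
  // Illustrative stand-in for a segment's state during merge.
  record Segment(int vectorCount, boolean hasDeletions) {}

  // Mirrors the policy above: largest vector count among segments with no deleted docs.
  static int selectInitializer(Segment[] segments) {
    int best = -1;
    int bestCount = 0;
    for (int i = 0; i < segments.length; i++) {
      if (segments[i].hasDeletions()) {
        continue; // criterion 1: the initializer must have no deleted docs
      }
      if (segments[i].vectorCount() > bestCount) { // criterion 2: the most vectors wins
        bestCount = segments[i].vectorCount();
        best = i;
      }
    }
    return best; // -1 means: no usable graph, build the merged graph from scratch
  }

  public static void main(String[] args) {
    Segment[] segments = {
      new Segment(1000, true),  // biggest, but has deletions
      new Segment(400, false),  // chosen: largest deletion-free segment
      new Segment(250, false),
    };
    System.out.println(selectInitializer(segments)); // 1
  }
}
```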