From 776149f0f6964bbc72ad2d292d1bfe770f82ba45 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 7 Feb 2023 11:42:03 -0800 Subject: [PATCH] Reuse HNSW graph for initialization during merge (#12050) * Remove implicit addition of vector 0 Removes logic to add 0 vector implicitly. This is in preparation for adding nodes from other graphs to initialize a new graph. Having the implicit addition of node 0 complicates this logic. Signed-off-by: John Mazanec * Enable out of order insertion of nodes in hnsw Enables nodes to be added into OnHeapHnswGraph in out of order fashion. To do so, additional operations have to be taken to resort the nodesByLevel array. Optimizations have been made to avoid sorting whenever possible. Signed-off-by: John Mazanec * Add ability to initialize from graph Adds method to initialize an HNSWGraphBuilder from another HNSWGraph. Initialization can only happen when the builder's graph is empty. Signed-off-by: John Mazanec * Utilize merge with graph init in HNSWWriter Uses HNSWGraphBuilder initialization from graph functionality in Lucene95HnswVectorsWriter. Selects the largest graph to initialize the new graph produced by the HNSWGraphBuilder for merge. Signed-off-by: John Mazanec * Minor modifications to Lucene95HnswVectorsWriter Signed-off-by: John Mazanec * Use TreeMap for graph structure for levels > 0 Refactors OnHeapHnswGraph to use TreeMap to represent graph structure of levels greater than 0. Refactors NodesIterator to support set representation of nodes. Signed-off-by: John Mazanec * Refactor initializer to be in static create method Refactors initialization from graph to be accessible via a create static method in HnswGraphBuilder. 
Signed-off-by: John Mazanec * Address review comments Signed-off-by: John Mazanec * Add change log entry Signed-off-by: John Mazanec * Remove empty iterator for neighborqueue Signed-off-by: John Mazanec --------- Signed-off-by: John Mazanec --- lucene/CHANGES.txt | 2 + .../lucene91/Lucene91HnswVectorsReader.java | 4 +- .../lucene91/Lucene91OnHeapHnswGraph.java | 4 +- .../lucene92/Lucene92HnswVectorsReader.java | 4 +- .../lucene94/Lucene94HnswVectorsReader.java | 4 +- .../lucene94/Lucene94HnswVectorsWriter.java | 7 +- .../lucene95/Lucene95HnswVectorsReader.java | 4 +- .../lucene95/Lucene95HnswVectorsWriter.java | 212 +++++++++++++-- .../apache/lucene/util/hnsw/HnswGraph.java | 90 +++++-- .../lucene/util/hnsw/HnswGraphBuilder.java | 91 ++++++- .../lucene/util/hnsw/HnswGraphSearcher.java | 7 +- .../lucene/util/hnsw/OnHeapHnswGraph.java | 109 ++++---- .../org/apache/lucene/index/TestKnnGraph.java | 84 ------ .../lucene/util/hnsw/HnswGraphTestCase.java | 250 +++++++++++++++++- .../util/hnsw/TestHnswByteVectorGraph.java | 28 ++ .../util/hnsw/TestHnswFloatVectorGraph.java | 29 ++ 16 files changed, 729 insertions(+), 200 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 1bc6cfe3ccd9..a55713e33ef6 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -133,6 +133,8 @@ Optimizations * GITHUB#12128, GITHUB#12133: Speed up docvalues set query by making use of sortedness. 
(Robert Muir, Uwe Schindler) +* GITHUB#12050: Reuse HNSW graph for initialization during merge (Jack Mazanec) + Bug Fixes --------------------- (No changes) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 2b46cb498ab1..8e66d0755878 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -561,9 +561,9 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length); + return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length); } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java index 2d3ef582b472..e762e016bbff 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java @@ -163,9 +163,9 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel.get(level), graph.get(level).size()); + return new ArrayNodesIterator(nodesByLevel.get(level), graph.get(level).size()); } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index 8fb1b3a92d36..df51972e8ddf 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -457,9 +457,9 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length); + return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length); } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index 51a9aa23c271..ee6472ab2d6c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -533,9 +533,9 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length); + return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length); } } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index 2e125c824629..f6f378027603 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ 
b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -345,7 +345,7 @@ public NodesIterator getNodesOnLevel(int level) { if (level == 0) { return graph.getNodesOnLevel(0); } else { - return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length); + return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length); } } }; @@ -687,10 +687,7 @@ public void addValue(int docID, Object value) throws IOException { assert docID > lastDocID; docsWithField.add(docID); vectors.add(copyValue(vectorValue)); - if (node > 0) { - // start at node 1! node 0 is added implicitly, in the constructor - hnswGraphBuilder.addGraphNode(node, vectorValue); - } + hnswGraphBuilder.addGraphNode(node, vectorValue); node++; lastDocID = docID; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java index 8d140b1fa1e0..185a472b5f48 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java @@ -573,9 +573,9 @@ public int entryNode() throws IOException { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length); + return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length); } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java index e9180d63d33d..bf0b79807f06 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java +++ 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -25,11 +25,16 @@ import java.nio.ByteOrder; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.*; import org.apache.lucene.index.Sorter; import org.apache.lucene.search.DocIdSetIterator; @@ -357,7 +362,7 @@ public NodesIterator getNodesOnLevel(int level) { if (level == 0) { return graph.getNodesOnLevel(0); } else { - return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length); + return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length); } } }; @@ -424,6 +429,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE int[][] vectorIndexNodeOffsets = null; if (docsWithField.cardinality() != 0) { // build graph + int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo); graph = switch (fieldInfo.getVectorEncoding()) { case BYTE -> { @@ -434,13 +440,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE vectorDataInput, byteSize); HnswGraphBuilder hnswGraphBuilder = - HnswGraphBuilder.create( - vectorValues, - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed); + createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); yield hnswGraphBuilder.build(vectorValues.copy()); } @@ -452,13 +452,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState 
mergeState) throws IOE vectorDataInput, byteSize); HnswGraphBuilder hnswGraphBuilder = - HnswGraphBuilder.create( - vectorValues, - fieldInfo.getVectorEncoding(), - fieldInfo.getVectorSimilarityFunction(), - M, - beamWidth, - HnswGraphBuilder.randSeed); + createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); yield hnswGraphBuilder.build(vectorValues.copy()); } @@ -489,6 +483,189 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE } } + private HnswGraphBuilder createHnswGraphBuilder( + MergeState mergeState, + FieldInfo fieldInfo, + RandomAccessVectorValues floatVectorValues, + int initializerIndex) + throws IOException { + if (initializerIndex == -1) { + return HnswGraphBuilder.create( + floatVectorValues, + fieldInfo.getVectorEncoding(), + fieldInfo.getVectorSimilarityFunction(), + M, + beamWidth, + HnswGraphBuilder.randSeed); + } + + HnswGraph initializerGraph = + getHnswGraphFromReader(fieldInfo.name, mergeState.knnVectorsReaders[initializerIndex]); + Map ordinalMapper = + getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex); + return HnswGraphBuilder.create( + floatVectorValues, + fieldInfo.getVectorEncoding(), + fieldInfo.getVectorSimilarityFunction(), + M, + beamWidth, + HnswGraphBuilder.randSeed, + initializerGraph, + ordinalMapper); + } + + private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo) + throws IOException { + // Find the KnnVectorReader with the most docs that meets the following criteria: + // 1. Does not contain any deleted docs + // 2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader + // If no readers exist that meet this criteria, return -1. 
If they do, return their index in + // merge state + int maxCandidateVectorCount = 0; + int initializerIndex = -1; + + for (int i = 0; i < mergeState.liveDocs.length; i++) { + KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i]; + if (mergeState.knnVectorsReaders[i] + instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name); + } + + if (!allMatch(mergeState.liveDocs[i]) + || !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader candidateReader)) { + continue; + } + + int candidateVectorCount = 0; + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> { + ByteVectorValues byteVectorValues = candidateReader.getByteVectorValues(fieldInfo.name); + if (byteVectorValues == null) { + continue; + } + candidateVectorCount = byteVectorValues.size(); + } + case FLOAT32 -> { + FloatVectorValues vectorValues = candidateReader.getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + candidateVectorCount = vectorValues.size(); + } + } + + if (candidateVectorCount > maxCandidateVectorCount) { + maxCandidateVectorCount = candidateVectorCount; + initializerIndex = i; + } + } + return initializerIndex; + } + + private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader) + throws IOException { + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader + && perFieldReader.getFieldReader(fieldName) + instanceof Lucene95HnswVectorsReader fieldReader) { + return fieldReader.getGraph(fieldName); + } + + if (knnVectorsReader instanceof Lucene95HnswVectorsReader) { + return ((Lucene95HnswVectorsReader) knnVectorsReader).getGraph(fieldName); + } + + // We should not reach here because knnVectorsReader's type is checked in + // selectGraphForInitialization + throw new IllegalArgumentException( + "Invalid KnnVectorsReader type for field: " + + fieldName + + ". 
Must be Lucene95HnswVectorsReader or newer"); + } + + private Map getOldToNewOrdinalMap( + MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException { + + DocIdSetIterator initializerIterator = null; + + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> initializerIterator = + mergeState.knnVectorsReaders[initializerIndex].getByteVectorValues(fieldInfo.name); + case FLOAT32 -> initializerIterator = + mergeState.knnVectorsReaders[initializerIndex].getFloatVectorValues(fieldInfo.name); + } + + MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex]; + + Map newIdToOldOrdinal = new HashMap<>(); + int oldOrd = 0; + int maxNewDocID = -1; + for (int oldId = initializerIterator.nextDoc(); + oldId != NO_MORE_DOCS; + oldId = initializerIterator.nextDoc()) { + if (isCurrentVectorNull(initializerIterator)) { + continue; + } + int newId = initializerDocMap.get(oldId); + maxNewDocID = Math.max(newId, maxNewDocID); + newIdToOldOrdinal.put(newId, oldOrd); + oldOrd++; + } + + if (maxNewDocID == -1) { + return Collections.emptyMap(); + } + + Map oldToNewOrdinalMap = new HashMap<>(); + + DocIdSetIterator vectorIterator = null; + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> vectorIterator = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + case FLOAT32 -> vectorIterator = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + } + + int newOrd = 0; + for (int newDocId = vectorIterator.nextDoc(); + newDocId <= maxNewDocID; + newDocId = vectorIterator.nextDoc()) { + if (isCurrentVectorNull(vectorIterator)) { + continue; + } + + if (newIdToOldOrdinal.containsKey(newDocId)) { + oldToNewOrdinalMap.put(newIdToOldOrdinal.get(newDocId), newOrd); + } + newOrd++; + } + + return oldToNewOrdinalMap; + } + + private boolean isCurrentVectorNull(DocIdSetIterator docIdSetIterator) throws IOException { + if (docIdSetIterator instanceof FloatVectorValues) { + return ((FloatVectorValues) 
docIdSetIterator).vectorValue() == null; + } + + if (docIdSetIterator instanceof ByteVectorValues) { + return ((ByteVectorValues) docIdSetIterator).vectorValue() == null; + } + + return true; + } + + private boolean allMatch(Bits bits) { + if (bits == null) { + return true; + } + + for (int i = 0; i < bits.length(); i++) { + if (!bits.get(i)) { + return false; + } + } + return true; + } + /** * @param graph Write the graph in a compressed format * @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets. @@ -735,10 +912,7 @@ public void addValue(int docID, T vectorValue) throws IOException { assert docID > lastDocID; docsWithField.add(docID); vectors.add(copyValue(vectorValue)); - if (node > 0) { - // start at node 1! node 0 is added implicitly, in the constructor - hnswGraphBuilder.addGraphNode(node, vectorValue); - } + hnswGraphBuilder.addGraphNode(node, vectorValue); node++; lastDocID = docID; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java index fc7b0be82fbb..9086ab55d2eb 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -20,6 +20,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import java.util.Collection; +import java.util.Iterator; import java.util.NoSuchElementException; import java.util.PrimitiveIterator; import org.apache.lucene.index.FloatVectorValues; @@ -115,7 +117,7 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { - return NodesIterator.EMPTY; + return ArrayNodesIterator.EMPTY; } }; @@ -123,33 +125,50 @@ public NodesIterator getNodesOnLevel(int level) { * Iterator over the graph nodes on a certain level, Iterator also provides the size – the total * number of nodes to be iterated over. 
*/ - public static final class NodesIterator implements PrimitiveIterator.OfInt { - static NodesIterator EMPTY = new NodesIterator(0); - - private final int[] nodes; - private final int size; - int cur = 0; - - /** Constructor for iterator based on the nodes array up to the size */ - public NodesIterator(int[] nodes, int size) { - assert nodes != null; - assert size <= nodes.length; - this.nodes = nodes; - this.size = size; - } + public abstract static class NodesIterator implements PrimitiveIterator.OfInt { + protected final int size; /** Constructor for iterator based on the size */ public NodesIterator(int size) { - this.nodes = null; this.size = size; } + /** The number of elements in this iterator * */ + public int size() { + return size; + } + /** * Consume integers from the iterator and place them into the `dest` array. * * @param dest where to put the integers * @return The number of integers written to `dest` */ + public abstract int consume(int[] dest); + } + + /** NodesIterator that accepts nodes as an integer array. */ + public static class ArrayNodesIterator extends NodesIterator { + static NodesIterator EMPTY = new ArrayNodesIterator(0); + + private final int[] nodes; + private int cur = 0; + + /** Constructor for iterator based on integer array representing nodes */ + public ArrayNodesIterator(int[] nodes, int size) { + super(size); + assert nodes != null; + assert size <= nodes.length; + this.nodes = nodes; + } + + /** Constructor for iterator based on the size */ + public ArrayNodesIterator(int size) { + super(size); + this.nodes = null; + } + + @Override public int consume(int[] dest) { if (hasNext() == false) { throw new NoSuchElementException(); @@ -182,10 +201,43 @@ public int nextInt() { public boolean hasNext() { return cur < size; } + } - /** The number of elements in this iterator * */ - public int size() { - return size; + /** Nodes iterator based on set representation of nodes. 
*/ + public static class CollectionNodesIterator extends NodesIterator { + Iterator nodes; + + /** Constructor for iterator based on collection representing nodes */ + public CollectionNodesIterator(Collection nodes) { + super(nodes.size()); + this.nodes = nodes.iterator(); + } + + @Override + public int consume(int[] dest) { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + + int destIndex = 0; + while (hasNext() && destIndex < dest.length) { + dest[destIndex++] = nextInt(); + } + + return destIndex; + } + + @Override + public int nextInt() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + return nodes.next(); + } + + @Override + public boolean hasNext() { + return nodes.hasNext(); } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index e29329932610..9f1e6c505254 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -18,10 +18,14 @@ package org.apache.lucene.util.hnsw; import static java.lang.Math.log; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import java.util.HashSet; import java.util.Locale; +import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; import org.apache.lucene.index.VectorEncoding; @@ -63,6 +67,7 @@ public final class HnswGraphBuilder { // we need two sources of vectors in order to perform diversity check comparisons without // colliding private final RandomAccessVectorValues vectorsCopy; + private final Set initializedNodes; public static HnswGraphBuilder create( RandomAccessVectorValues vectors, @@ -75,6 +80,22 @@ public static HnswGraphBuilder create( return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, 
seed); } + public static HnswGraphBuilder create( + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth, + long seed, + HnswGraph initializerGraph, + Map oldToNewOrdinalMap) + throws IOException { + HnswGraphBuilder hnswGraphBuilder = + new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed); + hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap); + return hnswGraphBuilder; + } + /** * Reads all the vectors from vector values, builds a graph connecting them by their dense * ordinals, using the given hyperparameter settings, and returns the resulting graph. @@ -110,8 +131,7 @@ private HnswGraphBuilder( // normalization factor for level generation; currently not configurable this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); this.random = new SplittableRandom(seed); - int levelOfFirstNode = getRandomGraphLevel(ml, random); - this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode); + this.hnsw = new OnHeapHnswGraph(M); this.graphSearcher = new HnswGraphSearcher<>( vectorEncoding, @@ -120,6 +140,7 @@ private HnswGraphBuilder( new FixedBitSet(this.vectors.size())); // in scratch we store candidates in reverse order: worse candidates are first scratch = new NeighborArray(Math.max(beamWidth, M + 1), false); + this.initializedNodes = new HashSet<>(); } /** @@ -142,10 +163,64 @@ public OnHeapHnswGraph build(RandomAccessVectorValues vectorsToAdd) throws IO return hnsw; } + /** + * Initializes the graph of this builder. Transfers the nodes and their neighbors from the + * initializer graph into the graph being produced by this builder, mapping ordinals from the + * initializer graph to their new ordinals in this builder's graph. The builder's graph must be + * empty before calling this method. 
+ * + * @param initializerGraph graph used for initialization + * @param oldToNewOrdinalMap map for converting from ordinals in the initializerGraph to this + * builder's graph + */ + private void initializeFromGraph( + HnswGraph initializerGraph, Map oldToNewOrdinalMap) throws IOException { + assert hnsw.size() == 0; + float[] vectorValue = null; + byte[] binaryValue = null; + for (int level = 0; level < initializerGraph.numLevels(); level++) { + HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level); + + while (it.hasNext()) { + int oldOrd = it.nextInt(); + int newOrd = oldToNewOrdinalMap.get(oldOrd); + + hnsw.addNode(level, newOrd); + + if (level == 0) { + initializedNodes.add(newOrd); + } + + switch (this.vectorEncoding) { + case FLOAT32 -> vectorValue = (float[]) vectors.vectorValue(newOrd); + case BYTE -> binaryValue = (byte[]) vectors.vectorValue(newOrd); + } + + NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd); + initializerGraph.seek(level, oldOrd); + for (int oldNeighbor = initializerGraph.nextNeighbor(); + oldNeighbor != NO_MORE_DOCS; + oldNeighbor = initializerGraph.nextNeighbor()) { + int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor); + float score = + switch (this.vectorEncoding) { + case FLOAT32 -> this.similarityFunction.compare( + vectorValue, (float[]) vectorsCopy.vectorValue(newNeighbor)); + case BYTE -> this.similarityFunction.compare( + binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor)); + }; + newNeighbors.insertSorted(newNeighbor, score); + } + } + } + } + private void addVectors(RandomAccessVectorValues vectorsToAdd) throws IOException { long start = System.nanoTime(), t = start; - // start at node 1! 
node 0 is added implicitly, in the constructor - for (int node = 1; node < vectorsToAdd.size(); node++) { + for (int node = 0; node < vectorsToAdd.size(); node++) { + if (initializedNodes.contains(node)) { + continue; + } addGraphNode(node, vectorsToAdd); if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) { t = printGraphBuildStatus(node, start, t); @@ -167,6 +242,14 @@ public void addGraphNode(int node, T value) throws IOException { NeighborQueue candidates; final int nodeLevel = getRandomGraphLevel(ml, random); int curMaxLevel = hnsw.numLevels() - 1; + + // If entrynode is -1, then this should finish without adding neighbors + if (hnsw.entryNode() == -1) { + for (int level = nodeLevel; level >= 0; level--) { + hnsw.addNode(level, node); + } + return; + } int[] eps = new int[] {hnsw.entryNode()}; // if a node introduces new levels to the graph, add this new node on new levels diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 13a338e0d183..4857d5b9d577 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -101,7 +101,12 @@ public static NeighborQueue search( new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); NeighborQueue results; - int[] eps = new int[] {graph.entryNode()}; + + int initialEp = graph.entryNode(); + if (initialEp == -1) { + return new NeighborQueue(1, true); + } + int[] eps = new int[] {initialEp}; int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index 78137c2a6302..9862536de08c 100644 --- 
a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -20,10 +20,9 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; +import java.util.TreeMap; import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.RamUsageEstimator; /** @@ -33,19 +32,20 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { private int numLevels; // the current number of levels in the graph - private int entryNode; // the current graph entry node on the top level + private int entryNode; // the current graph entry node on the top level. -1 if not set - // Nodes by level expressed as the level 0's nodes' ordinals. - // As level 0 contains all nodes, nodesByLevel.get(0) is null. - private final List nodesByLevel; - - // graph is a list of graph levels. - // Each level is represented as List – nodes' connections on this level. + // Level 0 is represented as List – nodes' connections on level 0. // Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond // to vectors // added to HnswBuilder, and the node values are the ordinals of those vectors. // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. - private final List> graph; + private final List graphLevel0; + // Represents levels 1-N. Each level is represented with a TreeMap that maps a levels level 0 + // ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain + // it in this list. However, to avoid changing list indexing, we always will make the first + // element + // null. 
+ private final List> graphUpperLevels; private final int nsize; private final int nsize0; @@ -53,24 +53,17 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { private int upto; private NeighborArray cur; - OnHeapHnswGraph(int M, int levelOfFirstNode) { - this.numLevels = levelOfFirstNode + 1; - this.graph = new ArrayList<>(numLevels); - this.entryNode = 0; + OnHeapHnswGraph(int M) { + this.numLevels = 1; // Implicitly start the graph with a single level + this.graphLevel0 = new ArrayList<>(); + this.entryNode = -1; // Entry node should be negative until a node is added // Neighbours' size on upper levels (nsize) and level 0 (nsize0) // We allocate extra space for neighbours, but then prune them to keep allowed maximum this.nsize = M + 1; this.nsize0 = (M * 2 + 1); - for (int l = 0; l < numLevels; l++) { - graph.add(new ArrayList<>()); - graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, true)); - } - this.nodesByLevel = new ArrayList<>(numLevels); - nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes - for (int l = 1; l < numLevels; l++) { - nodesByLevel.add(new int[] {0}); - } + this.graphUpperLevels = new ArrayList<>(numLevels); + graphUpperLevels.add(null); // we don't need this for 0th level, as it contains all nodes } /** @@ -81,49 +74,52 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { */ public NeighborArray getNeighbors(int level, int node) { if (level == 0) { - return graph.get(level).get(node); + return graphLevel0.get(node); } - int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node); - assert nodeIndex >= 0; - return graph.get(level).get(nodeIndex); + TreeMap levelMap = graphUpperLevels.get(level); + assert levelMap.containsKey(node); + return levelMap.get(node); } @Override public int size() { - return graph.get(0).size(); // all nodes are located on the 0th level + return graphLevel0.size(); // all nodes are 
located on the 0th level } /** - * Add node on the given level + * Add node on the given level. Nodes can be inserted out of order, but it requires that the nodes + * preceded by the node inserted out of order are eventually added. * * @param level level to add a node on * @param node the node to add, represented as an ordinal on the level 0. */ public void addNode(int level, int node) { + if (entryNode == -1) { + entryNode = node; + } + if (level > 0) { // if the new node introduces a new level, add more levels to the graph, // and make this node the graph's new entry point if (level >= numLevels) { for (int i = numLevels; i <= level; i++) { - graph.add(new ArrayList<>()); - nodesByLevel.add(new int[] {node}); + graphUpperLevels.add(new TreeMap<>()); } numLevels = level + 1; entryNode = node; - } else { - // Add this node id to this level's nodes - int[] nodes = nodesByLevel.get(level); - int idx = graph.get(level).size(); - if (idx < nodes.length) { - nodes[idx] = node; - } else { - nodes = ArrayUtil.grow(nodes); - nodes[idx] = node; - nodesByLevel.set(level, nodes); - } + } + + graphUpperLevels.get(level).put(node, new NeighborArray(nsize, true)); + } else { + // Add nodes all the way up to and including "node" in the new graph on level 0. This will + // cause the size of the + // graph to differ from the number of nodes added to the graph. The size of the graph and the + // number of nodes + // added will only be in sync once all nodes from 0...last_node are added into the graph. + while (node >= graphLevel0.size()) { + graphLevel0.add(new NeighborArray(nsize0, true)); } } - graph.get(level).add(new NeighborArray(level == 0 ? 
nsize0 : nsize, true)); } @Override @@ -164,9 +160,9 @@ public int entryNode() { @Override public NodesIterator getNodesOnLevel(int level) { if (level == 0) { - return new NodesIterator(size()); + return new ArrayNodesIterator(size()); } else { - return new NodesIterator(nodesByLevel.get(level), graph.get(level).size()); + return new CollectionNodesIterator(graphUpperLevels.get(level).keySet()); } } @@ -184,19 +180,26 @@ public long ramBytesUsed() { + Integer.BYTES * 2; long total = 0; for (int l = 0; l < numLevels; l++) { - int numNodesOnLevel = graph.get(l).size(); if (l == 0) { total += - numNodesOnLevel * neighborArrayBytes0 + graphLevel0.size() * neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph; } else { + long numNodesOnLevel = graphUpperLevels.get(l).size(); + + // For levels > 0, we represent the graph structure with a tree map. + // A single node in the tree contains 3 references (left root, right root, value) as well + // as an Integer for the key and 1 extra byte for the color of the node (this is actually 1 + // bit, but + // because we do not have that granularity, we set to 1 byte). In addition, we include 1 + // more reference for + // the tree map itself. 
total += - nodesByLevel.get(l).length * Integer.BYTES - + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER - + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for nodesByLevel - total += - numNodesOnLevel * neighborArrayBytes - + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph; + numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1) + + RamUsageEstimator.NUM_BYTES_OBJECT_REF; + + // Add the size neighbor of each node + total += numNodesOnLevel * neighborArrayBytes; } } return total; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index bba92fab2224..08f089430ba5 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -48,14 +48,12 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; -import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.junit.After; import org.junit.Before; @@ -179,21 +177,6 @@ public void testMerge() throws Exception { } } - /** - * Verify that we get the *same* graph by indexing one segment as we do by indexing two segments - * and merging. 
- */ - public void testMergeProducesSameGraph() throws Exception { - long seed = random().nextLong(); - int numDoc = atLeast(100); - int dimension = atLeast(10); - float[][] values = randomVectors(numDoc, dimension); - int mergePoint = random().nextInt(numDoc); - int[][][] mergedGraph = getIndexedGraph(values, mergePoint, seed); - int[][][] singleSegmentGraph = getIndexedGraph(values, -1, seed); - assertGraphEquals(singleSegmentGraph, mergedGraph); - } - /** Test writing and reading of multiple vector fields * */ public void testMultipleVectorFields() throws Exception { int numVectorFields = randomIntBetween(2, 5); @@ -227,52 +210,6 @@ public void testMultipleVectorFields() throws Exception { } } - private void assertGraphEquals(int[][][] expected, int[][][] actual) { - assertEquals("graph sizes differ", expected.length, actual.length); - for (int level = 0; level < expected.length; level++) { - for (int node = 0; node < expected[level].length; node++) { - assertArrayEquals("difference at ord=" + node, expected[level][node], actual[level][node]); - } - } - } - - /** - * Return a naive representation of an HNSW graph as a 3 dimensional array: 1st dim represents a - * graph layer. Each layer contains an array of arrays – a list of nodes and for each node a list - * of the node's neighbours. 2nd dim represents a node on a layer, and contains the node's - * neighbourhood, or {@code null} if a node is not present on this layer. 3rd dim represents - * neighbours of a node. 
- */ - private int[][][] getIndexedGraph(float[][] values, int mergePoint, long seed) - throws IOException { - HnswGraphBuilder.randSeed = seed; - int[][][] graph; - try (Directory dir = newDirectory()) { - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setMergePolicy(new LogDocMergePolicy()); // for predictable segment ordering when merging - iwc.setCodec(codec); // don't use SimpleTextCodec - try (IndexWriter iw = new IndexWriter(dir, iwc)) { - for (int i = 0; i < values.length; i++) { - add(iw, i, values[i]); - if (i == mergePoint) { - // flush proactively to create a segment - iw.flush(); - } - } - iw.forceMerge(1); - } - try (IndexReader reader = DirectoryReader.open(dir)) { - PerFieldKnnVectorsFormat.FieldsReader perFieldReader = - (PerFieldKnnVectorsFormat.FieldsReader) - ((CodecReader) getOnlyLeafReader(reader)).getVectorReader(); - Lucene95HnswVectorsReader vectorReader = - (Lucene95HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD); - graph = copyGraph(vectorReader.getGraph(KNN_GRAPH_FIELD)); - } - } - return graph; - } - private float[][] randomVectors(int numDoc, int dimension) { float[][] values = new float[numDoc][]; for (int i = 0; i < numDoc; i++) { @@ -297,27 +234,6 @@ private float[] randomVector(int dimension) { return value; } - int[][][] copyGraph(HnswGraph graphValues) throws IOException { - int[][][] graph = new int[graphValues.numLevels()][][]; - int size = graphValues.size(); - int[] scratch = new int[M * 2]; - - for (int level = 0; level < graphValues.numLevels(); level++) { - NodesIterator nodesItr = graphValues.getNodesOnLevel(level); - graph[level] = new int[size][]; - while (nodesItr.hasNext()) { - int node = nodesItr.nextInt(); - graphValues.seek(level, node); - int n, count = 0; - while ((n = graphValues.nextNeighbor()) != NO_MORE_DOCS) { - scratch[count++] = n; - } - graph[level][node] = ArrayUtil.copyOfSubArray(scratch, 0, count); - } - } - return graph; - } - /** Verify that searching does something 
reasonable */ public void testSearch() throws Exception { // We can't use dot product here since the vectors are laid out on a grid, not a sphere. diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 10ce6b42a78d..80c9c7a93cf4 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -25,10 +25,14 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.stream.Collectors; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene95.Lucene95Codec; import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat; @@ -84,6 +88,12 @@ abstract class HnswGraphTestCase extends LuceneTestCase { abstract AbstractMockVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException; + abstract AbstractMockVectorValues vectorValues( + int size, + int dimension, + AbstractMockVectorValues pregeneratedVectorValues, + int pregeneratedOffset); + abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction); abstract RandomAccessVectorValues circularVectorValues(int nDoc); @@ -427,6 +437,238 @@ public void testSearchWithSelectiveAcceptOrds() throws IOException { } } + public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException { + int maxNumLevels = randomIntBetween(2, 10); + int nodeCount = randomIntBetween(1, 100); + + List> nodesPerLevel = new ArrayList<>(); + for (int i = 0; i < maxNumLevels; i++) { + nodesPerLevel.add(new ArrayList<>()); + } + + int numLevels = 0; + for (int currNode = 0; currNode < nodeCount; currNode++) { + int nodeMaxLevel = 
random().nextInt(1, maxNumLevels + 1); + numLevels = Math.max(numLevels, nodeMaxLevel); + for (int currLevel = 0; currLevel < nodeMaxLevel; currLevel++) { + nodesPerLevel.get(currLevel).add(currNode); + } + } + + OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10); + for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) { + List currLevelNodes = nodesPerLevel.get(currLevel); + int currLevelNodesSize = currLevelNodes.size(); + for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) { + topDownOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd)); + } + } + + OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10); + for (int currLevel = 0; currLevel < numLevels; currLevel++) { + List currLevelNodes = nodesPerLevel.get(currLevel); + int currLevelNodesSize = currLevelNodes.size(); + for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) { + bottomUpOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd)); + } + } + + OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10); + for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) { + List currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel)); + Collections.shuffle(currLevelNodes, random()); + for (Integer currNode : currLevelNodes) { + topDownOrderRandomHnsw.addNode(currLevel, currNode); + } + } + + OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10); + for (int currLevel = 0; currLevel < numLevels; currLevel++) { + for (Integer currNode : nodesPerLevel.get(currLevel)) { + bottomUpExpectedHnsw.addNode(currLevel, currNode); + } + } + + assertEquals(nodeCount, bottomUpExpectedHnsw.getNodesOnLevel(0).size()); + for (Integer node : nodesPerLevel.get(0)) { + assertEquals(0, bottomUpExpectedHnsw.getNeighbors(0, node).size()); + } + + for (int currLevel = 1; currLevel < numLevels; currLevel++) { + NodesIterator nodesIterator = 
bottomUpExpectedHnsw.getNodesOnLevel(currLevel); + List expectedNodesOnLevel = nodesPerLevel.get(currLevel); + assertEquals(expectedNodesOnLevel.size(), nodesIterator.size()); + for (Integer expectedNode : expectedNodesOnLevel) { + int currentNode = nodesIterator.nextInt(); + assertEquals(expectedNode.intValue(), currentNode); + assertEquals(0, bottomUpExpectedHnsw.getNeighbors(currLevel, currentNode).size()); + } + } + + assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw); + assertGraphEqual(bottomUpExpectedHnsw, bottomUpOrderReversedHnsw); + assertGraphEqual(bottomUpExpectedHnsw, topDownOrderRandomHnsw); + } + + public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException { + int totalSize = atLeast(100); + int initializerSize = random().nextInt(5, totalSize); + int docIdOffset = 0; + int dim = atLeast(10); + long seed = random().nextLong(); + + AbstractMockVectorValues initializerVectors = vectorValues(initializerSize, dim); + HnswGraphBuilder initializerBuilder = + HnswGraphBuilder.create( + initializerVectors, getVectorEncoding(), similarityFunction, 10, 30, seed); + + OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy()); + AbstractMockVectorValues finalVectorValues = + vectorValues(totalSize, dim, initializerVectors, docIdOffset); + + Map initializerOrdMap = + createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset); + + HnswGraphBuilder finalBuilder = + HnswGraphBuilder.create( + finalVectorValues, + getVectorEncoding(), + similarityFunction, + 10, + 30, + seed, + initializerGraph, + initializerOrdMap); + + // When offset is 0, the graphs should be identical before vectors are added + assertGraphEqual(initializerGraph, finalBuilder.getGraph()); + + OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy()); + assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap); + } + + public void 
testHnswGraphBuilderInitializationFromGraph_withNonZeroOffset() throws IOException { + int totalSize = atLeast(100); + int initializerSize = random().nextInt(5, totalSize); + int docIdOffset = random().nextInt(1, totalSize - initializerSize + 1); + int dim = atLeast(10); + long seed = random().nextLong(); + + AbstractMockVectorValues initializerVectors = vectorValues(initializerSize, dim); + HnswGraphBuilder initializerBuilder = + HnswGraphBuilder.create( + initializerVectors.copy(), getVectorEncoding(), similarityFunction, 10, 30, seed); + OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy()); + AbstractMockVectorValues finalVectorValues = + vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset); + Map initializerOrdMap = + createOffsetOrdinalMap(initializerSize, finalVectorValues.copy(), docIdOffset); + + HnswGraphBuilder finalBuilder = + HnswGraphBuilder.create( + finalVectorValues, + getVectorEncoding(), + similarityFunction, + 10, + 30, + seed, + initializerGraph, + initializerOrdMap); + + assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap); + + // Confirm that the graph is appropriately constructed by checking that the nodes in the old + // graph are present in the levels of the new graph + OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy()); + assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap); + } + + private void assertGraphContainsGraph( + HnswGraph g, HnswGraph h, Map oldToNewOrdMap) throws IOException { + for (int i = 0; i < h.numLevels(); i++) { + int[] finalGraphNodesOnLevel = nodesIteratorToArray(g.getNodesOnLevel(i)); + int[] initializerGraphNodesOnLevel = + mapArrayAndSort(nodesIteratorToArray(h.getNodesOnLevel(i)), oldToNewOrdMap); + int overlap = computeOverlap(finalGraphNodesOnLevel, initializerGraphNodesOnLevel); + assertEquals(initializerGraphNodesOnLevel.length, overlap); + } + } + + private void 
assertGraphInitializedFromGraph( + HnswGraph g, HnswGraph h, Map oldToNewOrdMap) throws IOException { + assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels()); + // Confirm that the size of the new graph includes all nodes up to an including the max new + // ordinal in the old to + // new ordinal mapping + assertEquals( + "the number of nodes in the graphs are different!", + g.size(), + Collections.max(oldToNewOrdMap.values()) + 1); + + // assert the nodes from the previous graph are successfully to levels > 0 in the new graph + for (int level = 1; level < g.numLevels(); level++) { + NodesIterator nodesOnLevel = g.getNodesOnLevel(level); + NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level); + while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) { + int node = nodesOnLevel.nextInt(); + int node2 = oldToNewOrdMap.get(nodesOnLevel2.nextInt()); + assertEquals("nodes in the graphs are different", node, node2); + } + } + + // assert that the neighbors from the old graph are successfully transferred to the new graph + for (int level = 0; level < g.numLevels(); level++) { + NodesIterator nodesOnLevel = h.getNodesOnLevel(level); + while (nodesOnLevel.hasNext()) { + int node = nodesOnLevel.nextInt(); + g.seek(level, oldToNewOrdMap.get(node)); + h.seek(level, node); + assertEquals( + "arcs differ for node " + node, + getNeighborNodes(g), + getNeighborNodes(h).stream().map(oldToNewOrdMap::get).collect(Collectors.toSet())); + } + } + } + + private Map createOffsetOrdinalMap( + int docIdSize, AbstractMockVectorValues totalVectorValues, int docIdOffset) { + // Compute the offset for the ordinal map to be the number of non-null vectors in the total + // vector values + // before the docIdOffset + int ordinalOffset = 0; + while (totalVectorValues.nextDoc() < docIdOffset) { + ordinalOffset++; + } + + Map offsetOrdinalMap = new HashMap<>(); + for (int curr = 0; + totalVectorValues.docID() < docIdOffset + docIdSize; + 
totalVectorValues.nextDoc()) { + offsetOrdinalMap.put(curr, ordinalOffset + curr++); + } + + return offsetOrdinalMap; + } + + private int[] nodesIteratorToArray(NodesIterator nodesIterator) { + int[] arr = new int[nodesIterator.size()]; + int i = 0; + while (nodesIterator.hasNext()) { + arr[i++] = nodesIterator.nextInt(); + } + return arr; + } + + private int[] mapArrayAndSort(int[] arr, Map map) { + int[] mappedA = new int[arr.length]; + for (int i = 0; i < arr.length; i++) { + mappedA[i] = map.get(arr[i]); + } + Arrays.sort(mappedA); + return mappedA; + } + @SuppressWarnings("unchecked") public void testVisitedLimit() throws IOException { int nDoc = 500; @@ -531,8 +773,8 @@ public void testDiversity() throws IOException { HnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt()); // node 0 is added by the builder constructor - // builder.addGraphNode(vectors.vectorValue(0)); RandomAccessVectorValues vectorsCopy = vectors.copy(); + builder.addGraphNode(0, vectorsCopy); builder.addGraphNode(1, vectorsCopy); builder.addGraphNode(2, vectorsCopy); // now every node has tried to attach every other node as a neighbor, but @@ -586,9 +828,8 @@ public void testDiversityFallback() throws IOException { HnswGraphBuilder builder = HnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt()); - // node 0 is added by the builder constructor - // builder.addGraphNode(vectors.vectorValue(0)); RandomAccessVectorValues vectorsCopy = vectors.copy(); + builder.addGraphNode(0, vectorsCopy); builder.addGraphNode(1, vectorsCopy); builder.addGraphNode(2, vectorsCopy); assertLevel0Neighbors(builder.hnsw, 0, 1, 2); @@ -619,9 +860,8 @@ public void testDiversity3d() throws IOException { HnswGraphBuilder builder = HnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt()); - // node 0 is added by the builder constructor - // 
builder.addGraphNode(vectors.vectorValue(0)); RandomAccessVectorValues vectorsCopy = vectors.copy(); + builder.addGraphNode(0, vectorsCopy); builder.addGraphNode(1, vectorsCopy); builder.addGraphNode(2, vectorsCopy); assertLevel0Neighbors(builder.hnsw, 0, 1, 2); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java index 258864ade7dd..3a2d92ff92a6 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java @@ -85,6 +85,34 @@ AbstractMockVectorValues vectorValues(float[][] values) { return MockByteVectorValues.fromValues(bValues); } + @Override + AbstractMockVectorValues vectorValues( + int size, + int dimension, + AbstractMockVectorValues pregeneratedVectorValues, + int pregeneratedOffset) { + byte[][] vectors = new byte[size][]; + byte[][] randomVectors = + createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, random()); + + for (int i = 0; i < pregeneratedOffset; i++) { + vectors[i] = randomVectors[i]; + } + + int currentDoc; + while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) { + vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc]; + } + + for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length; + i < vectors.length; + i++) { + vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length]; + } + + return MockByteVectorValues.fromValues(vectors); + } + @Override AbstractMockVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException { diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 16f2e7330e27..5dda5bf0a838 100644 --- 
a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -79,6 +79,35 @@ AbstractMockVectorValues vectorValues(LeafReader reader, String fieldNa return MockVectorValues.fromValues(vectors); } + @Override + AbstractMockVectorValues vectorValues( + int size, + int dimension, + AbstractMockVectorValues pregeneratedVectorValues, + int pregeneratedOffset) { + float[][] vectors = new float[size][]; + float[][] randomVectors = + createRandomFloatVectors( + size - pregeneratedVectorValues.values.length, dimension, random()); + + for (int i = 0; i < pregeneratedOffset; i++) { + vectors[i] = randomVectors[i]; + } + + int currentDoc; + while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) { + vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc]; + } + + for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length; + i < vectors.length; + i++) { + vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length]; + } + + return MockVectorValues.fromValues(vectors); + } + @Override Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) { return new KnnFloatVectorField(name, vector, similarityFunction);