Skip to content

Commit

Permalink
Reuse HNSW graph for intialization during merge (#12050)
Browse files Browse the repository at this point in the history
* Remove implicit addition of vector 0

Removes logic to add 0 vector implicitly. This is in preparation for
adding nodes from other graphs to initialize a new graph. Having the
implicit addition of node 0 complicates this logic.

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Enable out of order insertion of nodes in hnsw

Enables nodes to be added into OnHeapHnswGraph in out of order fashion.
To do so, additional operations have to be taken to resort the
nodesByLevel array. Optimizations have been made to avoid sorting
whenever possible.

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Add ability to initialize from graph

Adds method to initialize an HNSWGraphBuilder from another HNSWGraph.
Initialization can only happen when the builder's graph is empty.

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Utilize merge with graph init in HNSWWriter

Uses HNSWGraphBuilder initialization from graph functionality in
Lucene95HnswVectorsWriter. Selects the largest graph to initialize the
new graph produced by the HNSWGraphBuilder for merge.

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Minor modifications to Lucene95HnswVectorsWriter

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Use TreeMap for graph structure for levels > 0

Refactors OnHeapHnswGraph to use TreeMap to represent graph structure of
levels greater than 0. Refactors NodesIterator to support set
representation of nodes.

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Refactor initializer to be in static create method

Refeactors initialization from graph to be accessible via a create
static method in HnswGraphBuilder.

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Address review comments

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Add change log entry

Signed-off-by: John Mazanec <jmazane@amazon.com>

* Remove empty iterator for neighborqueue

Signed-off-by: John Mazanec <jmazane@amazon.com>

---------

Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 authored Feb 7, 2023
1 parent ab074d5 commit 776149f
Show file tree
Hide file tree
Showing 16 changed files with 729 additions and 200 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ Optimizations

* GITHUB#12128, GITHUB#12133: Speed up docvalues set query by making use of sortedness. (Robert Muir, Uwe Schindler)

* GITHUB#12050: Reuse HNSW graph for intialization during merge (Jack Mazanec)

Bug Fixes
---------------------
(No changes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,9 @@ public int entryNode() {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ public int entryNode() {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
return new ArrayNodesIterator(nodesByLevel.get(level), graph.get(level).size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,9 @@ public int entryNode() {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,9 @@ public int entryNode() {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return graph.getNodesOnLevel(0);
} else {
return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
}
}
};
Expand Down Expand Up @@ -687,10 +687,7 @@ public void addValue(int docID, Object value) throws IOException {
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(copyValue(vectorValue));
if (node > 0) {
// start at node 1! node 0 is added implicitly, in the constructor
hnswGraphBuilder.addGraphNode(node, vectorValue);
}
hnswGraphBuilder.addGraphNode(node, vectorValue);
node++;
lastDocID = docID;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,9 +573,9 @@ public int entryNode() throws IOException {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.*;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
Expand Down Expand Up @@ -357,7 +362,7 @@ public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return graph.getNodesOnLevel(0);
} else {
return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
}
}
};
Expand Down Expand Up @@ -424,6 +429,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) {
// build graph
int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
graph =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> {
Expand All @@ -434,13 +440,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
vectorDataInput,
byteSize);
HnswGraphBuilder<byte[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
}
Expand All @@ -452,13 +452,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
vectorDataInput,
byteSize);
HnswGraphBuilder<float[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
}
Expand Down Expand Up @@ -489,6 +483,189 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
}
}

private <T> HnswGraphBuilder<T> createHnswGraphBuilder(
MergeState mergeState,
FieldInfo fieldInfo,
RandomAccessVectorValues<T> floatVectorValues,
int initializerIndex)
throws IOException {
if (initializerIndex == -1) {
return HnswGraphBuilder.create(
floatVectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
}

HnswGraph initializerGraph =
getHnswGraphFromReader(fieldInfo.name, mergeState.knnVectorsReaders[initializerIndex]);
Map<Integer, Integer> ordinalMapper =
getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
return HnswGraphBuilder.create(
floatVectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed,
initializerGraph,
ordinalMapper);
}

private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
throws IOException {
// Find the KnnVectorReader with the most docs that meets the following criteria:
// 1. Does not contain any deleted docs
// 2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader
// If no readers exist that meet this criteria, return -1. If they do, return their index in
// merge state
int maxCandidateVectorCount = 0;
int initializerIndex = -1;

for (int i = 0; i < mergeState.liveDocs.length; i++) {
KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
if (mergeState.knnVectorsReaders[i]
instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
}

if (!allMatch(mergeState.liveDocs[i])
|| !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader candidateReader)) {
continue;
}

int candidateVectorCount = 0;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> {
ByteVectorValues byteVectorValues = candidateReader.getByteVectorValues(fieldInfo.name);
if (byteVectorValues == null) {
continue;
}
candidateVectorCount = byteVectorValues.size();
}
case FLOAT32 -> {
FloatVectorValues vectorValues = candidateReader.getFloatVectorValues(fieldInfo.name);
if (vectorValues == null) {
continue;
}
candidateVectorCount = vectorValues.size();
}
}

if (candidateVectorCount > maxCandidateVectorCount) {
maxCandidateVectorCount = candidateVectorCount;
initializerIndex = i;
}
}
return initializerIndex;
}

private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader)
throws IOException {
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader
&& perFieldReader.getFieldReader(fieldName)
instanceof Lucene95HnswVectorsReader fieldReader) {
return fieldReader.getGraph(fieldName);
}

if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
return ((Lucene95HnswVectorsReader) knnVectorsReader).getGraph(fieldName);
}

// We should not reach here because knnVectorsReader's type is checked in
// selectGraphForInitialization
throw new IllegalArgumentException(
"Invalid KnnVectorsReader type for field: "
+ fieldName
+ ". Must be Lucene95HnswVectorsReader or newer");
}

private Map<Integer, Integer> getOldToNewOrdinalMap(
MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException {

DocIdSetIterator initializerIterator = null;

switch (fieldInfo.getVectorEncoding()) {
case BYTE -> initializerIterator =
mergeState.knnVectorsReaders[initializerIndex].getByteVectorValues(fieldInfo.name);
case FLOAT32 -> initializerIterator =
mergeState.knnVectorsReaders[initializerIndex].getFloatVectorValues(fieldInfo.name);
}

MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex];

Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
int oldOrd = 0;
int maxNewDocID = -1;
for (int oldId = initializerIterator.nextDoc();
oldId != NO_MORE_DOCS;
oldId = initializerIterator.nextDoc()) {
if (isCurrentVectorNull(initializerIterator)) {
continue;
}
int newId = initializerDocMap.get(oldId);
maxNewDocID = Math.max(newId, maxNewDocID);
newIdToOldOrdinal.put(newId, oldOrd);
oldOrd++;
}

if (maxNewDocID == -1) {
return Collections.emptyMap();
}

Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();

DocIdSetIterator vectorIterator = null;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> vectorIterator = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
case FLOAT32 -> vectorIterator =
MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}

int newOrd = 0;
for (int newDocId = vectorIterator.nextDoc();
newDocId <= maxNewDocID;
newDocId = vectorIterator.nextDoc()) {
if (isCurrentVectorNull(vectorIterator)) {
continue;
}

if (newIdToOldOrdinal.containsKey(newDocId)) {
oldToNewOrdinalMap.put(newIdToOldOrdinal.get(newDocId), newOrd);
}
newOrd++;
}

return oldToNewOrdinalMap;
}

private boolean isCurrentVectorNull(DocIdSetIterator docIdSetIterator) throws IOException {
if (docIdSetIterator instanceof FloatVectorValues) {
return ((FloatVectorValues) docIdSetIterator).vectorValue() == null;
}

if (docIdSetIterator instanceof ByteVectorValues) {
return ((ByteVectorValues) docIdSetIterator).vectorValue() == null;
}

return true;
}

private boolean allMatch(Bits bits) {
if (bits == null) {
return true;
}

for (int i = 0; i < bits.length(); i++) {
if (!bits.get(i)) {
return false;
}
}
return true;
}

/**
* @param graph Write the graph in a compressed format
* @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets.
Expand Down Expand Up @@ -735,10 +912,7 @@ public void addValue(int docID, T vectorValue) throws IOException {
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(copyValue(vectorValue));
if (node > 0) {
// start at node 1! node 0 is added implicitly, in the constructor
hnswGraphBuilder.addGraphNode(node, vectorValue);
}
hnswGraphBuilder.addGraphNode(node, vectorValue);
node++;
lastDocID = docID;
}
Expand Down
Loading

0 comments on commit 776149f

Please sign in to comment.