Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse HNSW graph for initialization during merge #12050

Merged
merged 10 commits into from
Feb 7, 2023
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ Optimizations
in order to achieve the same false positive probability with less memory.
(Jean-François Boeuf)

* GITHUB#12050: Reuse HNSW graph for initialization during merge (Jack Mazanec)

Bug Fixes
---------------------
(No changes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,9 @@ public int entryNode() {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ public int entryNode() {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
return new ArrayNodesIterator(nodesByLevel.get(level), graph.get(level).size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,9 @@ public int entryNode() {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,9 @@ public int entryNode() {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return graph.getNodesOnLevel(0);
} else {
return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
}
}
};
Expand Down Expand Up @@ -687,10 +687,7 @@ public void addValue(int docID, Object value) throws IOException {
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(copyValue(vectorValue));
if (node > 0) {
// start at node 1! node 0 is added implicitly, in the constructor
hnswGraphBuilder.addGraphNode(node, vectorValue);
}
hnswGraphBuilder.addGraphNode(node, vectorValue);
node++;
lastDocID = docID;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,9 +573,9 @@ public int entryNode() throws IOException {
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.*;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
Expand Down Expand Up @@ -357,7 +362,7 @@ public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return graph.getNodesOnLevel(0);
} else {
return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
}
}
};
Expand Down Expand Up @@ -424,6 +429,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) {
// build graph
int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
graph =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> {
Expand All @@ -434,13 +440,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
vectorDataInput,
byteSize);
HnswGraphBuilder<byte[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
}
Expand All @@ -452,13 +452,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
vectorDataInput,
byteSize);
HnswGraphBuilder<float[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
}
Expand Down Expand Up @@ -489,6 +483,189 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
}
}

/**
 * Creates an {@link HnswGraphBuilder} for the merged segment's vectors.
 *
 * <p>If {@code initializerIndex} is -1 (no suitable segment graph was found by {@code
 * selectGraphForInitialization}), a builder starting from an empty graph is returned. Otherwise
 * the builder is seeded with the HNSW graph of the chosen segment, with that graph's vector
 * ordinals remapped to the merged segment's ordinals.
 *
 * @param mergeState state of the ongoing merge; supplies per-segment readers and doc maps
 * @param fieldInfo the vector field being merged
 * @param vectorValues the merged segment's vector values (element type is byte[] or float[],
 *     matching the field's vector encoding)
 * @param initializerIndex index into {@code mergeState.knnVectorsReaders} of the graph to reuse,
 *     or -1 to build from scratch
 * @throws IOException if reading the initializer graph or vector values fails
 */
private <T> HnswGraphBuilder<T> createHnswGraphBuilder(
    MergeState mergeState,
    FieldInfo fieldInfo,
    RandomAccessVectorValues<T> vectorValues,
    int initializerIndex)
    throws IOException {
  if (initializerIndex == -1) {
    // No reusable graph: build the merged graph from an empty initial state.
    return HnswGraphBuilder.create(
        vectorValues,
        fieldInfo.getVectorEncoding(),
        fieldInfo.getVectorSimilarityFunction(),
        M,
        beamWidth,
        HnswGraphBuilder.randSeed);
  }

  // Seed the builder with the selected segment's graph; ordinalMapper translates that
  // segment's vector ordinals into ordinals of the merged segment.
  HnswGraph initializerGraph =
      getHnswGraphFromReader(fieldInfo.name, mergeState.knnVectorsReaders[initializerIndex]);
  Map<Integer, Integer> ordinalMapper =
      getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
  return HnswGraphBuilder.create(
      vectorValues,
      fieldInfo.getVectorEncoding(),
      fieldInfo.getVectorSimilarityFunction(),
      M,
      beamWidth,
      HnswGraphBuilder.randSeed,
      initializerGraph,
      ordinalMapper);
}

/**
 * Selects the segment whose existing HNSW graph should seed the merged graph.
 *
 * <p>A candidate segment must have no deleted docs (all its liveDocs bits set, or liveDocs
 * null) and its field reader must be a {@code Lucene95HnswVectorsReader}, possibly wrapped in a
 * {@code PerFieldKnnVectorsFormat.FieldsReader}. Among candidates, the one with the most
 * vectors for this field wins, so the largest existing graph gets reused.
 *
 * @param mergeState state of the ongoing merge; one reader/liveDocs entry per source segment
 * @param fieldInfo the vector field being merged
 * @return index of the chosen reader in {@code mergeState.knnVectorsReaders}, or -1 if no
 *     segment qualifies
 * @throws IOException if reading vector values fails
 */
private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
    throws IOException {
  // Find the KnnVectorReader with the most docs that meets the following criteria:
  // 1. Does not contain any deleted docs
  // 2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader
  // If no readers exist that meet this criteria, return -1. If they do, return their index in
  // merge state
  int maxCandidateVectorCount = 0;
  int initializerIndex = -1;

  for (int i = 0; i < mergeState.liveDocs.length; i++) {
    KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
    // Unwrap per-field readers to reach the field-specific reader.
    if (mergeState.knnVectorsReaders[i]
        instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
      currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
    }

    // Skip segments with deletions, or whose reader cannot supply an HNSW graph.
    if (!allMatch(mergeState.liveDocs[i])
        || !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader candidateReader)) {
      continue;
    }

    int candidateVectorCount = 0;
    switch (fieldInfo.getVectorEncoding()) {
      case BYTE -> {
        ByteVectorValues byteVectorValues = candidateReader.getByteVectorValues(fieldInfo.name);
        if (byteVectorValues == null) {
          // Segment has no byte vectors for this field; not a candidate.
          continue;
        }
        candidateVectorCount = byteVectorValues.size();
      }
      case FLOAT32 -> {
        FloatVectorValues vectorValues = candidateReader.getFloatVectorValues(fieldInfo.name);
        if (vectorValues == null) {
          // Segment has no float vectors for this field; not a candidate.
          continue;
        }
        candidateVectorCount = vectorValues.size();
      }
    }

    // Track the segment contributing the most vectors seen so far.
    if (candidateVectorCount > maxCandidateVectorCount) {
      maxCandidateVectorCount = candidateVectorCount;
      initializerIndex = i;
    }
  }
  return initializerIndex;
}

private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader)
throws IOException {
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader
&& perFieldReader.getFieldReader(fieldName)
instanceof Lucene95HnswVectorsReader fieldReader) {
return fieldReader.getGraph(fieldName);
}

if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
return ((Lucene95HnswVectorsReader) knnVectorsReader).getGraph(fieldName);
}

zhaih marked this conversation as resolved.
Show resolved Hide resolved
// We should not reach here because knnVectorsReader's type is checked in
// selectGraphForInitialization
throw new IllegalArgumentException(
"Invalid KnnVectorsReader type for field: "
+ fieldName
+ ". Must be Lucene95HnswVectorsReader or newer");
}

/**
 * Builds a map from vector ordinals in the initializer segment's graph to vector ordinals in
 * the merged segment.
 *
 * <p>Pass 1 walks the initializer segment's vectors in doc order, using the segment's {@code
 * MergeState.DocMap} to record, for each merged docID, the vector's old ordinal. Pass 2 walks
 * the merged vector values in doc order, assigning new ordinals, and pairs old with new
 * ordinals for the docs recorded in pass 1. Only docs up to the largest remapped docID are
 * visited in pass 2; later docs cannot appear in the map.
 *
 * @param mergeState state of the ongoing merge
 * @param fieldInfo the vector field being merged
 * @param initializerIndex index of the initializer segment in {@code
 *     mergeState.knnVectorsReaders}
 * @return old-ordinal to new-ordinal map; empty if the initializer segment contributed no
 *     vectors
 * @throws IOException if iterating vector values fails
 */
private Map<Integer, Integer> getOldToNewOrdinalMap(
    MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException {

  DocIdSetIterator initializerIterator = null;

  // Iterator over the initializer segment's vectors, in that segment's doc order.
  switch (fieldInfo.getVectorEncoding()) {
    case BYTE -> initializerIterator =
        mergeState.knnVectorsReaders[initializerIndex].getByteVectorValues(fieldInfo.name);
    case FLOAT32 -> initializerIterator =
        mergeState.knnVectorsReaders[initializerIndex].getFloatVectorValues(fieldInfo.name);
  }

  // Maps the initializer segment's docIDs to docIDs in the merged segment.
  MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex];

  Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
  int oldOrd = 0;
  int maxNewDocID = -1;
  for (int oldId = initializerIterator.nextDoc();
      oldId != NO_MORE_DOCS;
      oldId = initializerIterator.nextDoc()) {
    if (isCurrentVectorNull(initializerIterator)) {
      // Docs without an actual vector did not get a graph ordinal.
      continue;
    }
    int newId = initializerDocMap.get(oldId);
    maxNewDocID = Math.max(newId, maxNewDocID);
    newIdToOldOrdinal.put(newId, oldOrd);
    oldOrd++;
  }

  if (maxNewDocID == -1) {
    // Initializer segment contributed no vectors at all.
    return Collections.emptyMap();
  }

  Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();

  // Iterator over ALL merged vectors in merged doc order; its position defines new ordinals.
  DocIdSetIterator vectorIterator = null;
  switch (fieldInfo.getVectorEncoding()) {
    case BYTE -> vectorIterator = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
    case FLOAT32 -> vectorIterator =
        MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
  }

  int newOrd = 0;
  // Loop ends once past maxNewDocID; NO_MORE_DOCS (Integer.MAX_VALUE) also terminates it.
  for (int newDocId = vectorIterator.nextDoc();
      newDocId <= maxNewDocID;
      newDocId = vectorIterator.nextDoc()) {
    if (isCurrentVectorNull(vectorIterator)) {
      continue;
    }

    if (newIdToOldOrdinal.containsKey(newDocId)) {
      oldToNewOrdinalMap.put(newIdToOldOrdinal.get(newDocId), newOrd);
    }
    newOrd++;
  }

  return oldToNewOrdinalMap;
}

/**
 * Returns true if the iterator's currently-positioned vector is null (such docs are skipped
 * when assigning graph ordinals). Iterators that are neither {@code FloatVectorValues} nor
 * {@code ByteVectorValues} are conservatively treated as having no vector.
 *
 * @param docIdSetIterator iterator positioned on the doc to inspect
 * @throws IOException if reading the vector value fails
 */
private boolean isCurrentVectorNull(DocIdSetIterator docIdSetIterator) throws IOException {
  if (docIdSetIterator instanceof FloatVectorValues floatVectorValues) {
    return floatVectorValues.vectorValue() == null;
  }

  if (docIdSetIterator instanceof ByteVectorValues byteVectorValues) {
    return byteVectorValues.vectorValue() == null;
  }

  return true;
}

/**
 * Returns true when every bit in {@code bits} is set. A null {@code bits} is treated as
 * all-matching (used for liveDocs: no bits present means no docs are filtered out).
 *
 * @param bits the bit set to inspect; may be null
 */
private boolean allMatch(Bits bits) {
  if (bits == null) {
    return true;
  }

  final int length = bits.length();
  for (int idx = 0; idx < length; idx++) {
    if (bits.get(idx) == false) {
      return false;
    }
  }
  return true;
}

/**
* @param graph Write the graph in a compressed format
* @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets.
Expand Down Expand Up @@ -735,10 +912,7 @@ public void addValue(int docID, T vectorValue) throws IOException {
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(copyValue(vectorValue));
if (node > 0) {
// start at node 1! node 0 is added implicitly, in the constructor
hnswGraphBuilder.addGraphNode(node, vectorValue);
}
hnswGraphBuilder.addGraphNode(node, vectorValue);
node++;
lastDocID = docID;
}
Expand Down
Loading