Skip to content

Commit

Permalink
Use TreeMap for graph structure for levels > 0
Browse files Browse the repository at this point in the history
Refactors OnHeapHnswGraph to use TreeMap to represent graph structure of
levels greater than 0. Refactors NodesIterator to support set
representation of nodes.

Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 committed Jan 30, 2023
1 parent 4cc64ae commit e6b8a07
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,9 @@ public int entryNode() {
/**
 * Returns an iterator over the nodes on the given graph level.
 *
 * @param level graph level to iterate; level 0 is handled separately from the higher levels
 * @return for level 0, an iterator built from {@code size()} alone; otherwise an iterator over
 *     {@code nodesByLevel[level]} bounded by that array's length
 */
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
// NOTE(review): rendered diff -- each NodesIterator line appears to be the removed side and the
// ArrayNodesIterator line directly below it the added side; confirm against the real file.
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ public int entryNode() {
/**
 * Returns an iterator over the nodes on the given graph level.
 *
 * @param level graph level to iterate; level 0 is handled separately from the higher levels
 * @return for level 0, an iterator built from {@code size()} alone; otherwise an iterator over
 *     {@code nodesByLevel.get(level)} whose size is taken from {@code graph.get(level).size()}
 */
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
// NOTE(review): rendered diff -- each NodesIterator line appears to be the removed side and the
// ArrayNodesIterator line directly below it the added side; confirm against the real file.
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
// The node array and the element count come from two different structures here.
return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
return new ArrayNodesIterator(nodesByLevel.get(level), graph.get(level).size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,9 @@ public int entryNode() {
/**
 * Returns an iterator over the nodes on the given graph level.
 *
 * @param level graph level to iterate; level 0 is handled separately from the higher levels
 * @return for level 0, an iterator built from {@code size()} alone; otherwise an iterator over
 *     {@code nodesByLevel[level]} bounded by that array's length
 */
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
// NOTE(review): rendered diff -- each NodesIterator line appears to be the removed side and the
// ArrayNodesIterator line directly below it the added side; confirm against the real file.
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,9 @@ public int entryNode() {
/**
 * Returns an iterator over the nodes on the given graph level.
 *
 * @param level graph level to iterate; level 0 is handled separately from the higher levels
 * @return for level 0, an iterator built from {@code size()} alone; otherwise an iterator over
 *     {@code nodesByLevel[level]} bounded by that array's length
 */
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
// NOTE(review): rendered diff -- each NodesIterator line appears to be the removed side and the
// ArrayNodesIterator line directly below it the added side; confirm against the real file.
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return graph.getNodesOnLevel(0);
} else {
return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,9 +573,9 @@ public int entryNode() throws IOException {
/**
 * Returns an iterator over the nodes on the given graph level.
 *
 * @param level graph level to iterate; level 0 is handled separately from the higher levels
 * @return for level 0, an iterator built from {@code size()} alone; otherwise an iterator over
 *     {@code nodesByLevel[level]} bounded by that array's length
 */
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
// NOTE(review): rendered diff -- each NodesIterator line appears to be the removed side and the
// ArrayNodesIterator line directly below it the added side; confirm against the real file.
return new NodesIterator(size());
return new ArrayNodesIterator(size());
} else {
return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return graph.getNodesOnLevel(0);
} else {
return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
}
}
};
Expand Down Expand Up @@ -533,12 +533,24 @@ private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldI
continue;
}

VectorValues vectorValues = candidateReader.getVectorValues(fieldInfo.name);
if (vectorValues == null) {
continue;
int candidateVectorCount = 0;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> {
ByteVectorValues byteVectorValues = candidateReader.getByteVectorValues(fieldInfo.name);
if (byteVectorValues == null) {
continue;
}
candidateVectorCount = byteVectorValues.size();
}
case FLOAT32 -> {
FloatVectorValues vectorValues = candidateReader.getFloatVectorValues(fieldInfo.name);
if (vectorValues == null) {
continue;
}
candidateVectorCount = vectorValues.size();
}
}

int candidateVectorCount = vectorValues.size();
if (candidateVectorCount > maxCandidateVectorCount) {
maxCandidateVectorCount = candidateVectorCount;
initializerIndex = i;
Expand Down Expand Up @@ -569,17 +581,25 @@ private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnV

private Map<Integer, Integer> getOldToNewOrdinalMap(
MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException {
VectorValues initializerVectorValues =
mergeState.knnVectorsReaders[initializerIndex].getVectorValues(fieldInfo.name);

DocIdSetIterator initializerIterator = null;

switch (fieldInfo.getVectorEncoding()) {
case BYTE -> initializerIterator =
mergeState.knnVectorsReaders[initializerIndex].getByteVectorValues(fieldInfo.name);
case FLOAT32 -> initializerIterator =
mergeState.knnVectorsReaders[initializerIndex].getFloatVectorValues(fieldInfo.name);
}

MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex];

Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
int oldOrd = 0;
int maxNewDocID = -1;
for (int oldId = initializerVectorValues.nextDoc();
for (int oldId = initializerIterator.nextDoc();
oldId != NO_MORE_DOCS;
oldId = initializerVectorValues.nextDoc()) {
if (initializerVectorValues.vectorValue() == null) {
oldId = initializerIterator.nextDoc()) {
if (isCurrentVectorNull(initializerIterator)) {
continue;
}
int newId = initializerDocMap.get(oldId);
Expand All @@ -593,12 +613,19 @@ private Map<Integer, Integer> getOldToNewOrdinalMap(
}

Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();
VectorValues vectorValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);

DocIdSetIterator vectorIterator = null;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> vectorIterator = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
case FLOAT32 -> vectorIterator =
MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}

int newOrd = 0;
for (int newDocId = vectorValues.nextDoc();
for (int newDocId = vectorIterator.nextDoc();
newDocId <= maxNewDocID;
newDocId = vectorValues.nextDoc()) {
if (vectorValues.vectorValue() == null) {
newDocId = vectorIterator.nextDoc()) {
if (isCurrentVectorNull(vectorIterator)) {
continue;
}

Expand All @@ -611,6 +638,18 @@ private Map<Integer, Integer> getOldToNewOrdinalMap(
return oldToNewOrdinalMap;
}

/**
 * Reports whether the given iterator has no vector at its current position.
 *
 * <p>Only {@link FloatVectorValues} and {@link ByteVectorValues} are asked for a vector; an
 * iterator of any other type is reported as having a null vector.
 *
 * @param docIdSetIterator iterator to inspect at its current position
 * @return {@code true} if the current vector is {@code null} or the iterator type is not one of
 *     the two recognised vector-value types
 * @throws IOException declared for the underlying {@code vectorValue()} reads
 */
private boolean isCurrentVectorNull(DocIdSetIterator docIdSetIterator) throws IOException {
  final Object currentVector;
  if (docIdSetIterator instanceof FloatVectorValues) {
    currentVector = ((FloatVectorValues) docIdSetIterator).vectorValue();
  } else if (docIdSetIterator instanceof ByteVectorValues) {
    currentVector = ((ByteVectorValues) docIdSetIterator).vectorValue();
  } else {
    // Unrecognised iterator type: treat it as having no vector.
    return true;
  }
  return currentVector == null;
}

private boolean allMatch(Bits bits) {
if (bits == null) {
return true;
Expand Down
90 changes: 71 additions & 19 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.PrimitiveIterator;
import java.util.Set;
import org.apache.lucene.index.FloatVectorValues;

/**
Expand Down Expand Up @@ -115,41 +117,58 @@ public int entryNode() {

/** Returns the shared {@code EMPTY} iterator for every level; {@code level} is not consulted. */
@Override
public NodesIterator getNodesOnLevel(int level) {
// NOTE(review): rendered diff -- the NodesIterator.EMPTY line appears to be the removed side and
// the ArrayNodesIterator.EMPTY line the added side; confirm against the real file.
return NodesIterator.EMPTY;
return ArrayNodesIterator.EMPTY;
}
};

/**
* Iterator over the graph nodes on a certain level. The iterator also provides the size, i.e. the
* total number of nodes to be iterated over.
*/
public static final class NodesIterator implements PrimitiveIterator.OfInt {
static NodesIterator EMPTY = new NodesIterator(0);

private final int[] nodes;
private final int size;
int cur = 0;

/** Constructor for iterator based on the nodes array up to the size */
public NodesIterator(int[] nodes, int size) {
assert nodes != null;
assert size <= nodes.length;
this.nodes = nodes;
this.size = size;
}
/**
 * Base class for iterators over graph nodes that also report how many nodes they cover. It
 * stores only the size; subclasses supply the iteration methods and {@link #consume(int[])}.
 */
public abstract static class NodesIterator implements PrimitiveIterator.OfInt {
// Number of nodes reported by size(); fixed at construction.
protected final int size;

/** Constructor for iterator based on the size */
public NodesIterator(int size) {
// NOTE(review): rendered diff -- this class declares no `nodes` field, so the assignment below
// appears to be the removed side of the diff; confirm against the real file.
this.nodes = null;
this.size = size;
}

/** The number of elements in this iterator. */
public int size() {
return size;
}

/**
 * Consume integers from the iterator and place them into the {@code dest} array.
 *
 * @param dest where to put the integers
 * @return The number of integers written to {@code dest}
 */
public abstract int consume(int[] dest);
}

/** NodesIterator that accepts nodes as an integer array. */
public static class ArrayNodesIterator extends NodesIterator {
static NodesIterator EMPTY = new ArrayNodesIterator(0);

private final int[] nodes;
private int cur = 0;

/** Constructor for iterator based on integer array representing nodes */
public ArrayNodesIterator(int[] nodes, int size) {
super(size);
assert nodes != null;
assert size <= nodes.length;
this.nodes = nodes;
}

/** Constructor for iterator based on the size */
public ArrayNodesIterator(int size) {
super(size);
this.nodes = null;
}

@Override
public int consume(int[] dest) {
if (hasNext() == false) {
throw new NoSuchElementException();
Expand Down Expand Up @@ -182,10 +201,43 @@ public int nextInt() {
public boolean hasNext() {
return cur < size;
}
}

/** The number of elements in this iterator * */
public int size() {
return size;
/**
 * Nodes iterator based on set representation of nodes.
 *
 * <p>The reported size is the size of the set at construction time, and iteration follows the
 * order of the set's own iterator.
 */
public static class SetNodesIterator extends NodesIterator {
  // Kept private and final, matching the array-backed iterator's node storage, so the iteration
  // state cannot be swapped out from elsewhere in the package.
  private final Iterator<Integer> nodes;

  /**
   * Constructor for iterator based on set representing nodes
   *
   * @param nodes the node ids to iterate over; must not be {@code null}
   */
  public SetNodesIterator(Set<Integer> nodes) {
    super(nodes.size());
    this.nodes = nodes.iterator();
  }

  /**
   * Copies remaining node ids into {@code dest} until either the iterator is exhausted or
   * {@code dest} is full.
   *
   * @param dest where to put the integers
   * @return the number of integers written to {@code dest}
   * @throws NoSuchElementException if the iterator is already exhausted when called
   */
  @Override
  public int consume(int[] dest) {
    if (hasNext() == false) {
      throw new NoSuchElementException();
    }

    int destIndex = 0;
    while (hasNext() && destIndex < dest.length) {
      dest[destIndex++] = nextInt();
    }

    return destIndex;
  }

  /**
   * Returns the next node id.
   *
   * @throws NoSuchElementException if no nodes remain
   */
  @Override
  public int nextInt() {
    if (hasNext() == false) {
      throw new NoSuchElementException();
    }
    // Unboxes the Integer supplied by the underlying set iterator.
    return nodes.next();
  }

  /** Returns whether the underlying set iterator has more nodes. */
  @Override
  public boolean hasNext() {
    return nodes.hasNext();
  }
}
}
Loading

0 comments on commit e6b8a07

Please sign in to comment.