Skip to content

Commit

Permalink
Use HashMap (was TreeMap) for OnHeapHnswGraph neighbors
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis authored and alessandrobenedetti committed May 12, 2023
1 parent 36d6824 commit 533c6fc
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Arrays;
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
Expand All @@ -36,7 +37,6 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
Expand Down Expand Up @@ -227,11 +227,10 @@ private void writeMeta(
} else {
meta.writeInt(graph.numLevels());
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
meta.writeInt(sortedNodes.length); // number of nodes on a level
if (level > 0) {
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
for (int node : sortedNodes) {
meta.writeInt(node); // list of nodes on a level
}
}
Expand All @@ -257,9 +256,8 @@ private Lucene91OnHeapHnswGraph writeGraph(
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
for (int node : sortedNodes) {
Lucene91NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
Expand All @@ -39,7 +40,6 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
Expand Down Expand Up @@ -261,11 +261,10 @@ private void writeMeta(
} else {
meta.writeInt(graph.numLevels());
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
meta.writeInt(sortedNodes.length); // number of nodes on a level
if (level > 0) {
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
for (int node : sortedNodes) {
meta.writeInt(node); // list of nodes on a level
}
}
Expand Down Expand Up @@ -293,9 +292,8 @@ private OnHeapHnswGraph writeGraph(
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
int maxConnOnLevel = level == 0 ? (M * 2) : M;
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
for (int node : sortedNodes) {
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
Expand Down Expand Up @@ -303,9 +304,8 @@ private HnswGraph reconstructAndWriteGraph(
for (int level = 1; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
int[] newNodes = new int[nodesOnLevel.size()];
int n = 0;
while (nodesOnLevel.hasNext()) {
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
for (int n = 0; nodesOnLevel.hasNext(); n++) {
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
}
Arrays.sort(newNodes);
nodesByLevel.add(newNodes);
Expand Down Expand Up @@ -481,9 +481,8 @@ private void writeGraph(OnHeapHnswGraph graph) throws IOException {
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
int maxConnOnLevel = level == 0 ? (M * 2) : M;
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
for (int node : sortedNodes) {
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);
Expand Down Expand Up @@ -570,11 +569,10 @@ private void writeMeta(
} else {
meta.writeInt(graph.numLevels());
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
meta.writeInt(sortedNodes.length); // number of nodes on a level
if (level > 0) {
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
for (int node : sortedNodes) {
meta.writeInt(node); // list of nodes on a level
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,8 @@ private HnswGraph reconstructAndWriteGraph(
for (int level = 1; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
int[] newNodes = new int[nodesOnLevel.size()];
int n = 0;
while (nodesOnLevel.hasNext()) {
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
for (int n = 0; nodesOnLevel.hasNext(); n++) {
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
}
Arrays.sort(newNodes);
nodesByLevel.add(newNodes);
Expand Down Expand Up @@ -677,11 +676,10 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException {
int countOnLevel0 = graph.size();
int[][] offsets = new int[graph.numLevels()][];
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
offsets[level] = new int[nodesOnLevel.size()];
int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level));
offsets[level] = new int[sortedNodes.length];
int nodeOffsetId = 0;
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
for (int node : sortedNodes) {
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
// Write size in VInt as the neighbors list is typically small
Expand All @@ -706,6 +704,15 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException {
return offsets;
}

public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
int[] sortedNodes = new int[nodesOnLevel.size()];
for (int n = 0; nodesOnLevel.hasNext(); n++) {
sortedNodes[n] = nodesOnLevel.nextInt();
}
Arrays.sort(sortedNodes);
return sortedNodes;
}

private void writeMeta(
FieldInfo field,
int maxDoc,
Expand Down Expand Up @@ -779,6 +786,7 @@ private void writeMeta(
if (level > 0) {
int[] nol = new int[nodesOnLevel.size()];
int numberConsumed = nodesOnLevel.consume(nol);
Arrays.sort(nol);
assert numberConsumed == nodesOnLevel.size();
meta.writeVInt(nol.length); // number of nodes on a level
for (int i = nodesOnLevel.size() - 1; i > 0; --i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ protected HnswGraph() {}
public abstract int entryNode() throws IOException;

/**
* Get all nodes on a given level as node 0th ordinals
* Get all nodes on a given level as node 0th ordinals. The nodes are NOT guaranteed to be
* presented in any particular order.
*
* @param level level for which to get all nodes
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
Expand Down Expand Up @@ -123,7 +124,8 @@ public NodesIterator getNodesOnLevel(int level) {

/**
* Iterator over the graph nodes on a certain level, Iterator also provides the size – the total
* number of nodes to be iterated over.
* number of nodes to be iterated over. The nodes are NOT guaranteed to be presented in any
* particular order.
*/
public abstract static class NodesIterator implements PrimitiveIterator.OfInt {
protected final int size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.TreeMap;
import java.util.Map;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;

Expand All @@ -40,12 +41,12 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
// added to HnswBuilder, and the node values are the ordinals of those vectors.
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
private final List<NeighborArray> graphLevel0;
// Represents levels 1-N. Each level is represented with a TreeMap that maps a levels level 0
// Represents levels 1-N. Each level is represented with a Map that maps a levels level 0
// ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain
// it in this list. However, to avoid changing list indexing, we always will make the first
// element
// null.
private final List<TreeMap<Integer, NeighborArray>> graphUpperLevels;
private final List<Map<Integer, NeighborArray>> graphUpperLevels;
private final int nsize;
private final int nsize0;

Expand Down Expand Up @@ -76,7 +77,7 @@ public NeighborArray getNeighbors(int level, int node) {
if (level == 0) {
return graphLevel0.get(node);
}
TreeMap<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
Map<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
assert levelMap.containsKey(node);
return levelMap.get(node);
}
Expand All @@ -103,7 +104,7 @@ public void addNode(int level, int node) {
// and make this node the graph's new entry point
if (level >= numLevels) {
for (int i = numLevels; i <= level; i++) {
graphUpperLevels.add(new TreeMap<>());
graphUpperLevels.add(new HashMap<>());
}
numLevels = level + 1;
entryNode = node;
Expand Down Expand Up @@ -204,4 +205,15 @@ public long ramBytesUsed() {
}
return total;
}

@Override
public String toString() {
return "OnHeapHnswGraph(size="
+ size()
+ ", numLevels="
+ numLevels
+ ", entryNode="
+ entryNode
+ ")";
}
}
Loading

0 comments on commit 533c6fc

Please sign in to comment.