Utilize merge with graph init in HNSWWriter
Uses the HnswGraphBuilder initialize-from-graph functionality in
Lucene95HnswVectorsWriter. During merge, selects the largest eligible
graph to initialize the new graph produced by the HnswGraphBuilder.

Signed-off-by: John Mazanec <jmazane@amazon.com>
jmazanec15 committed Jan 30, 2023
1 parent 447237d commit ca6861f
Showing 2 changed files with 127 additions and 84 deletions.
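The change, in a nutshell: rather than building the merged HNSW graph from scratch, the writer now seeds it with the graph of the largest segment that has no deletions. Below is a minimal, self-contained sketch of that selection policy; the `Segment` record and every name in it are illustrative stand-ins, not Lucene types (the real logic is `selectGraphForInitialization` in the diff that follows).

```java
import java.util.List;

public class SelectInitializerSketch {
  // Illustrative stand-in for a segment being merged; not a Lucene type.
  record Segment(String name, int vectorCount, boolean hasDeletions) {}

  // Mirrors the shape of selectGraphForInitialization: return the index of
  // the largest deletion-free segment, or -1 if none qualifies.
  static int selectInitializer(List<Segment> segments) {
    int best = -1;
    int maxVectors = 0;
    for (int i = 0; i < segments.size(); i++) {
      Segment s = segments.get(i);
      if (s.hasDeletions()) {
        continue; // a deleted doc would leave a hole in the copied graph
      }
      if (s.vectorCount() > maxVectors) {
        maxVectors = s.vectorCount();
        best = i;
      }
    }
    return best;
  }

  public static void main(String[] args) {
    List<Segment> segments =
        List.of(
            new Segment("_0", 1000, true), // disqualified: has deletions
            new Segment("_1", 400, false),
            new Segment("_2", 900, false)); // largest clean segment
    System.out.println(selectInitializer(segments)); // prints 2
  }
}
```

The intent is that the bulk of the neighbor lists survive the merge, and only the vectors from the remaining segments have to be inserted from scratch.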
127 changes: 127 additions & 0 deletions lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java
@@ -25,11 +25,16 @@
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.*;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
@@ -442,6 +447,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
maybeInitializeFromGraph(hnswGraphBuilder, mergeState, fieldInfo);
yield hnswGraphBuilder.build(vectorValues.copy());
}
case FLOAT32 -> {
@@ -460,6 +466,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
maybeInitializeFromGraph(hnswGraphBuilder, mergeState, fieldInfo);
yield hnswGraphBuilder.build(vectorValues.copy());
}
};
@@ -489,6 +496,126 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException
}
}

private void maybeInitializeFromGraph(
HnswGraphBuilder<?> hnswGraphBuilder, MergeState mergeState, FieldInfo fieldInfo)
throws IOException {
int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
if (initializerIndex == -1) {
return;
}

HnswGraph initializerGraph =
getHnswGraphFromReader(fieldInfo.name, mergeState.knnVectorsReaders[initializerIndex]);
Map<Integer, Integer> ordinalMapper =
getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
hnswGraphBuilder.initializeFromGraph(initializerGraph, ordinalMapper);
}

private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
throws IOException {
// Find the KnnVectorsReader with the most vectors that meets the following criteria:
// 1. Does not contain any deleted docs
// 2. Is a Lucene95HnswVectorsReader, possibly wrapped in a PerFieldKnnVectorsFormat.FieldsReader
// If no reader meets these criteria, return -1; otherwise, return that reader's index in the
// merge state.
int maxCandidateVectorCount = 0;
int initializerIndex = -1;

for (int i = 0; i < mergeState.liveDocs.length; i++) {
KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
if (mergeState.knnVectorsReaders[i]
instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
}

if (!allMatch(mergeState.liveDocs[i])
|| !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader candidateReader)) {
continue;
}

VectorValues vectorValues = candidateReader.getVectorValues(fieldInfo.name);
if (vectorValues == null) {
continue;
}

int candidateVectorCount = vectorValues.size();
if (candidateVectorCount > maxCandidateVectorCount) {
maxCandidateVectorCount = candidateVectorCount;
initializerIndex = i;
}
}
return initializerIndex;
}

private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader)
throws IOException {
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader
&& perFieldReader.getFieldReader(fieldName)
instanceof Lucene95HnswVectorsReader fieldReader) {
return fieldReader.getGraph(fieldName);
}

if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
return ((Lucene95HnswVectorsReader) knnVectorsReader).getGraph(fieldName);
}

throw new IllegalArgumentException(
"Invalid KnnVectorsReader. Must be of type PerFieldKnnVectorsFormat.FieldsReader or Lucene94HnswVectorsReader");
}

private Map<Integer, Integer> getOldToNewOrdinalMap(
MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException {
VectorValues initializerVectorValues =
mergeState.knnVectorsReaders[initializerIndex].getVectorValues(fieldInfo.name);
MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex];

// First pass: for each initializer-segment doc that has a vector, map its doc
// ID in the merged segment to its ordinal in the initializer graph.
Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
int oldOrd = 0;
for (int oldId = initializerVectorValues.nextDoc();
oldId != NO_MORE_DOCS;
oldId = initializerVectorValues.nextDoc()) {
if (initializerVectorValues.vectorValue() == null) {
continue;
}
int newId = initializerDocMap.get(oldId);
newIdToOldOrdinal.put(newId, oldOrd);
oldOrd++;
}

// Second pass: walk the merged segment's vectors in doc-ID order to assign new
// ordinals, recording old ordinal -> new ordinal for vectors that came from
// the initializer segment.
Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();
int newOrd = 0;
int maxNewDocID = Collections.max(newIdToOldOrdinal.keySet());
VectorValues vectorValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);

for (int newDocId = vectorValues.nextDoc();
newDocId <= maxNewDocID;
newDocId = vectorValues.nextDoc()) {
if (vectorValues.vectorValue() == null) {
continue;
}

if (newIdToOldOrdinal.containsKey(newDocId)) {
oldToNewOrdinalMap.put(newIdToOldOrdinal.get(newDocId), newOrd);
}
newOrd++;
}

return oldToNewOrdinalMap;
}

private boolean allMatch(Bits bits) {
if (bits == null) {
return true;
}

for (int i = 0; i < bits.length(); i++) {
if (!bits.get(i)) {
return false;
}
}
return true;
}

/**
* @param graph Write the graph in a compressed format
* @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets.
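The subtlest piece above is getOldToNewOrdinalMap. HNSW graph nodes are addressed by vector ordinal (the position of a vector within its segment), not by doc ID, so seeding the merged graph requires translating each ordinal in the initializer segment into its ordinal in the merged segment, going through doc IDs and the merge-state doc map. Here is a self-contained sketch of that two-pass translation with made-up doc IDs; plain collections stand in for the Lucene readers and MergeState.DocMap.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class OrdinalRemapSketch {
  public static void main(String[] args) {
    // Doc IDs (in order) that carry vectors in the initializer segment.
    List<Integer> initializerDocs = List.of(0, 2, 3);
    // Old doc ID -> doc ID in the merged segment (the MergeState.DocMap role).
    Map<Integer, Integer> docMap = Map.of(0, 5, 2, 7, 3, 8);
    // Doc IDs (in order) that carry vectors in the merged segment.
    List<Integer> mergedDocs = List.of(1, 4, 5, 7, 8);

    // Pass 1: merged doc ID -> ordinal of that vector in the initializer graph.
    Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
    int oldOrd = 0;
    for (int oldDocId : initializerDocs) {
      newIdToOldOrdinal.put(docMap.get(oldDocId), oldOrd++);
    }

    // Pass 2: walk merged docs in order to assign new ordinals; for docs that
    // came from the initializer segment, record old ordinal -> new ordinal.
    Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();
    int newOrd = 0;
    for (int newDocId : mergedDocs) {
      Integer old = newIdToOldOrdinal.get(newDocId);
      if (old != null) {
        oldToNewOrdinalMap.put(old, newOrd);
      }
      newOrd++;
    }

    System.out.println(oldToNewOrdinalMap); // {0=2, 1=3, 2=4}
  }
}
```

In the example, the initializer's vectors keep their relative order but shift to ordinals 2 through 4, because two vectors from other segments sort ahead of them in the merged segment.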
84 changes: 0 additions & 84 deletions lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -48,14 +48,12 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.After;
import org.junit.Before;

@@ -179,21 +177,6 @@ public void testMerge() throws Exception {
}
}

/**
* Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
* and merging.
*/
public void testMergeProducesSameGraph() throws Exception {
long seed = random().nextLong();
int numDoc = atLeast(100);
int dimension = atLeast(10);
float[][] values = randomVectors(numDoc, dimension);
int mergePoint = random().nextInt(numDoc);
int[][][] mergedGraph = getIndexedGraph(values, mergePoint, seed);
int[][][] singleSegmentGraph = getIndexedGraph(values, -1, seed);
assertGraphEquals(singleSegmentGraph, mergedGraph);
}

/** Test writing and reading of multiple vector fields */
public void testMultipleVectorFields() throws Exception {
int numVectorFields = randomIntBetween(2, 5);
@@ -227,52 +210,6 @@ public void testMultipleVectorFields() throws Exception {
}
}

private void assertGraphEquals(int[][][] expected, int[][][] actual) {
assertEquals("graph sizes differ", expected.length, actual.length);
for (int level = 0; level < expected.length; level++) {
for (int node = 0; node < expected[level].length; node++) {
assertArrayEquals("difference at ord=" + node, expected[level][node], actual[level][node]);
}
}
}

/**
* Return a naive representation of an HNSW graph as a 3 dimensional array: 1st dim represents a
* graph layer. Each layer contains an array of arrays – a list of nodes and for each node a list
* of the node's neighbours. 2nd dim represents a node on a layer, and contains the node's
* neighbourhood, or {@code null} if a node is not present on this layer. 3rd dim represents
* neighbours of a node.
*/
private int[][][] getIndexedGraph(float[][] values, int mergePoint, long seed)
throws IOException {
HnswGraphBuilder.randSeed = seed;
int[][][] graph;
try (Directory dir = newDirectory()) {
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setMergePolicy(new LogDocMergePolicy()); // for predictable segment ordering when merging
iwc.setCodec(codec); // don't use SimpleTextCodec
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
for (int i = 0; i < values.length; i++) {
add(iw, i, values[i]);
if (i == mergePoint) {
// flush proactively to create a segment
iw.flush();
}
}
iw.forceMerge(1);
}
try (IndexReader reader = DirectoryReader.open(dir)) {
PerFieldKnnVectorsFormat.FieldsReader perFieldReader =
(PerFieldKnnVectorsFormat.FieldsReader)
((CodecReader) getOnlyLeafReader(reader)).getVectorReader();
Lucene95HnswVectorsReader vectorReader =
(Lucene95HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
graph = copyGraph(vectorReader.getGraph(KNN_GRAPH_FIELD));
}
}
return graph;
}

private float[][] randomVectors(int numDoc, int dimension) {
float[][] values = new float[numDoc][];
for (int i = 0; i < numDoc; i++) {
@@ -297,27 +234,6 @@ private float[] randomVector(int dimension) {
return value;
}

int[][][] copyGraph(HnswGraph graphValues) throws IOException {
int[][][] graph = new int[graphValues.numLevels()][][];
int size = graphValues.size();
int[] scratch = new int[M * 2];

for (int level = 0; level < graphValues.numLevels(); level++) {
NodesIterator nodesItr = graphValues.getNodesOnLevel(level);
graph[level] = new int[size][];
while (nodesItr.hasNext()) {
int node = nodesItr.nextInt();
graphValues.seek(level, node);
int n, count = 0;
while ((n = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
scratch[count++] = n;
}
graph[level][node] = ArrayUtil.copyOfSubArray(scratch, 0, count);
}
}
return graph;
}

/** Verify that searching does something reasonable */
public void testSearch() throws Exception {
// We can't use dot product here since the vectors are laid out on a grid, not a sphere.
